Mojo struct
TileWriter
struct TileWriter[tma_origin: ImmutOrigin, c_type: DType, c_rank: Int, c_tile_shape: IndexList[c_rank], c_desc_shape: IndexList[c_rank], //, a_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], opc: OutputPipelineConfig, c_swizzle: TensorMapSwizzle, transpose_c: Bool, c_smem_dim0: Int, c_smem_dim1: Int, num_output_stages: Int, num_output_warps: Int, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, batched: Bool = False, problem_n: Int = 0]
Output tile writer for SM100 matmul epilogue.
Stores a pointer to the TMA descriptor; SMEM tiles are passed on each call.
Parameters are passed explicitly so that the struct works with both MatmulConfig and BlockScaledMatmulConfig.
The opc (OutputPipelineConfig) parameter must match the config used when constructing the OutputTilePipeline that provides OutputStage instances to the write() method.
Fields
- c_tma_op (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].TmaOpPtr): Pointer to the TMA descriptor.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
accum_tile_layout
comptime accum_tile_layout = Layout.row_major(TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].BM, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN)
AccumTmemArray
comptime AccumTmemArray = TmemArrayType[accum_type, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].accum_tile_layout, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].num_stages, cta_group=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].cta_group]
bits
comptime bits = 256
BM
comptime BM = block_tile_shape[0]
BN
comptime BN = block_tile_shape[1]
c_smem_layout
comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)
cta_group
comptime cta_group = opc.cta_group
CTileArray
comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]
data_paths
comptime data_paths = 16
epc
comptime epc = EpilogueConfig.create(MMA_M=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].MMA_M, MMA_N=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].MMA_N, stageN=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN, cta_group=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].cta_group, transpose_c=transpose_c, BM=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].BM, BN=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].BN)
epilogue_dtype
comptime epilogue_dtype = TileWriter.get_epilogue_dtype()
fragment_size
comptime fragment_size = (128 // WARP_SIZE)
is_lower_frag_required
comptime is_lower_frag_required = TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].epc.is_lower_frag_required
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
N_dim
comptime N_dim = 0 if transpose_c else 1
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = opc.num_stages
num_stages
comptime num_stages = TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].epc.num_stages
rep
comptime rep = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN // 8)
rep_frag_size
comptime rep_frag_size = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].fragment_size * TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].rep)
Stage
comptime Stage = OutputStage[opc]
stage_contiguous_size
comptime stage_contiguous_size = c_smem_dim1
stage_stride_cols
comptime stage_stride_cols = opc.stage_stride_cols
stageN
comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1
TmaOp
comptime TmaOp = TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape]
TmaOpPtr
comptime TmaOpPtr = Pointer[TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape], tma_origin]
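The shape-derived members listed above follow simple integer arithmetic. The sketch below is a plain-Python model of those formulas, not Mojo and not part of the library: the function name `derived_constants` and the example shapes are hypothetical, and `WARP_SIZE = 32` is an assumption about the target GPU.

```python
WARP_SIZE = 32  # assumption: 32 threads per warp


def derived_constants(block_tile_shape, mma_shape, c_smem_dim0, c_smem_dim1, transpose_c):
    """Hypothetical helper mirroring the documented comptime formulas."""
    # stageN picks the SMEM dimension that runs along N.
    stage_n = c_smem_dim0 if transpose_c else c_smem_dim1
    fragment_size = 128 // WARP_SIZE
    rep = stage_n // 8
    return {
        "BM": block_tile_shape[0],
        "BN": block_tile_shape[1],
        "MMA_M": mma_shape[0],
        "MMA_N": mma_shape[1],
        "N_dim": 0 if transpose_c else 1,
        "stageN": stage_n,
        "stage_contiguous_size": c_smem_dim1,
        "fragment_size": fragment_size,
        "rep": rep,
        "rep_frag_size": fragment_size * rep,
    }


# Example shapes are made up for illustration only.
consts = derived_constants(
    block_tile_shape=(128, 64, 64),
    mma_shape=(128, 64, 16),
    c_smem_dim0=128,
    c_smem_dim1=32,
    transpose_c=False,
)
print(consts)
```

With `transpose_c=False`, `stageN` follows `c_smem_dim1`; flipping the flag makes it follow `c_smem_dim0` and sets `N_dim` to 0.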
Methods
__init__
__init__(c_tma_op: Pointer[TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape], tma_origin]) -> Self
Initialize with a pointer to the TMA descriptor.
get_epilogue_dtype
write
write(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write accumulated results to global memory (2D coords).
write_batched
write_batched(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32, UInt32], shape: Tuple[UInt32, UInt32], alpha: Float32 = 1)
Write accumulated results to global memory (3D batched coords).
Args:
- c_tiles (SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]): TileTensor-based SMEM tile array for C output.
- stage (OutputStage[opc]): OutputStage with pipeline, index, and TMEM handle.
- tile_coord (Tuple[UInt32, UInt32, UInt32]): (m_tile, n_tile, batch) coordinates.
- shape (Tuple[UInt32, UInt32]): (M, N) problem dimensions.
- alpha (Float32): Tensor scale factor (scalar).
write_splitk
write_splitk[reduction_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], scheduler: TileScheduler, reduction_tensor: TileTensor[accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write with split-K reduction. Only the last split writes to GMEM.
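The split-K behavior can be pictured with a small host-side model. This is a sketch in plain Python under stated assumptions, not the kernel code: `write_splitk_model` is a hypothetical name, a flat list stands in for `reduction_tensor`, and treating the final loop iteration as "the last split" is an assumption (how the scheduler actually identifies the last split is not described here).

```python
def write_splitk_model(partials):
    """Hypothetical model: accumulate per-split partial tiles, write once at the end."""
    reduction = [0.0] * len(partials[0])  # stands in for reduction_tensor
    gmem = None                           # stands in for the output tile in GMEM
    gmem_writes = 0
    for split_idx, partial in enumerate(partials):
        # Every split adds its partial accumulators into the reduction buffer.
        for k, value in enumerate(partial):
            reduction[k] += value
        # Assumption: the final split to arrive is the only one that stores to GMEM.
        if split_idx == len(partials) - 1:
            gmem = list(reduction)
            gmem_writes += 1
    return gmem, gmem_writes


result, writes = write_splitk_model([[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]])
print(result, writes)
```

The point of the model is the write count: however many splits contribute, the output tile is stored once, after all partial sums have been combined.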
write_absolute_with_bounds_check
write_absolute_with_bounds_check[c_tensor_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: TileTensor[c_type, c_tensor_layout, MutAnyOrigin])
Write with absolute coordinates and bounds checking.
Intended for 1D-1D grouped kernels, where the M coordinate is absolute.
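A bounds-checked write with absolute coordinates might look like the plain-Python sketch below. Several details are assumptions made for illustration and are not confirmed by this page: that `m_end` is an exclusive row bound, that columns are clipped at the tensor width, and that `expert_scale` multiplies each value before the store. The function name `write_absolute_model` is hypothetical.

```python
def write_absolute_model(c_tensor, tile, m_abs, n_abs, m_end, expert_scale=1.0):
    """Hypothetical model of a bounds-checked tile store at absolute (m_abs, n_abs)."""
    n_cols = len(c_tensor[0])
    for i, row in enumerate(tile):
        m = m_abs + i
        if m >= m_end:  # assumption: m_end is exclusive, rows past it are skipped
            break
        for j, value in enumerate(row):
            n = n_abs + j
            if n < n_cols:  # assumption: columns past the tensor edge are skipped
                c_tensor[m][n] = value * expert_scale  # assumption: scale applied on store
    return c_tensor


c = [[0.0] * 4 for _ in range(4)]
out = write_absolute_model(c, [[1.0, 2.0], [3.0, 4.0]], m_abs=2, n_abs=1, m_end=3, expert_scale=2.0)
print(out)
```

In this example the tile's second row would land on row 3, which is at `m_end`, so only the first row is stored.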
write_with_residual
write_with_residual(self, out_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], src_tile: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], src_stage_idx: UInt32, beta: Scalar[c_type], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write with residual: D = lambda(accum) + beta * C.
This method extends the standard write() to add a residual term loaded from source tensor C in shared memory. The epilogue load warp pre-fetches C tiles into src_tile before this method is called.
Pipeline:
- Load accum from TMEM to registers
- Apply epilogue lambda (if present)
- Load C fragment from source SMEM
- Compute D = accum + beta * C
- Write D to output SMEM and TMA store to GMEM
Args:
- out_tiles (SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]): Output SMEM tile array (for D output).
- stage (OutputStage[opc]): OutputStage with pipeline, index, and TMEM handle.
- src_tile (SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]): Source C SMEM tile array (TileTensor-based, from epilogue load warp via smem.src_tiles()).
- src_stage_idx (UInt32): Stage index into src_tile (0 or 1 for double-buffer).
- beta (Scalar[c_type]): Residual scale factor.
- tile_coord (Tuple[UInt32, UInt32]): (m_tile, n_tile) coordinates.
- shape (Tuple[UInt32, UInt32]): (M, N) problem dimensions.
- elect_one_warp (Bool): Whether this warp is elected for coordination.
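The residual formula D = lambda(accum) + beta * C can be checked on a tiny tile with a plain-Python model. This is only a sketch of the arithmetic: `write_with_residual_model` is a hypothetical name, nested lists stand in for the TMEM accumulator and the SMEM source tile, and the lambda is modeled as a function of (coordinates, value) to echo the shape of `elementwise_compute_lambda_fn`.

```python
def write_with_residual_model(accum, c_src, beta, epilogue=None):
    """Hypothetical model: D[i][j] = epilogue((i, j), accum[i][j]) + beta * C[i][j]."""
    out = []
    for i, (acc_row, c_row) in enumerate(zip(accum, c_src)):
        out_row = []
        for j, (acc, c) in enumerate(zip(acc_row, c_row)):
            # Apply the epilogue lambda if one is present, otherwise pass through.
            value = epilogue((i, j), acc) if epilogue is not None else acc
            # Add the scaled residual term from the source tile.
            out_row.append(value + beta * c)
        out.append(out_row)
    return out


accum = [[1.0, 2.0], [3.0, 4.0]]
c_src = [[10.0, 20.0], [30.0, 40.0]]
d = write_with_residual_model(accum, c_src, beta=0.5, epilogue=lambda coords, v: 2.0 * v)
print(d)
```

Passing `epilogue=None` reduces the model to D = accum + beta * C, which matches the pipeline when no lambda is supplied.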