For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
TileWriter
struct TileWriter[tma_origin: ImmutOrigin, c_type: DType, c_rank: Int, c_tile_shape: IndexList[c_rank], c_desc_shape: IndexList[c_rank], //, a_type: DType, accum_type: DType, block_tile_shape: IndexList[Int(3)], mma_shape: IndexList[Int(3)], opc: OutputPipelineConfig, c_swizzle: TensorMapSwizzle, transpose_c: Bool, c_smem_dim0: Int, c_smem_dim1: Int, num_output_stages: Int, num_output_warps: Int, elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, batched: Bool = False, problem_n: Int = Int(0)]
Output tile writer for SM100 matmul epilogue.
Stores pointer to TMA descriptor. SMEM tiles passed per-call.
Parameters are passed explicitly to work with both MatmulConfig and BlockScaledMatmulConfig.
The opc (OutputPipelineConfig) parameter must match the config used when constructing the OutputTilePipeline that provides OutputStage instances to the write() method.
Fields
- c_tma_op (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].TmaOpPtr): Pointer to the TMA descriptor.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
accum_tile_layout
comptime accum_tile_layout = Layout.row_major(block_tile_shape[Int(0)], TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN)
AccumTmemArray
comptime AccumTmemArray = TmemArrayType[accum_type, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].accum_tile_layout, EpilogueConfig.create(MMA_M=mma_shape[Int(0)], MMA_N=mma_shape[Int(1)], stageN=c_smem_dim0 if transpose_c else c_smem_dim1, cta_group=opc.cta_group, transpose_c=transpose_c, BM=block_tile_shape[Int(0)], BN=block_tile_shape[Int(1)]).num_stages, cta_group=opc.cta_group]
bits
comptime bits = 256
BM
comptime BM = block_tile_shape[Int(0)]
BN
comptime BN = block_tile_shape[Int(1)]
c_smem_layout
comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)
cta_group
comptime cta_group = opc.cta_group
CTileArray
comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]
data_paths
comptime data_paths = 16
epc
comptime epc = EpilogueConfig.create(MMA_M=mma_shape[Int(0)], MMA_N=mma_shape[Int(1)], stageN=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN, cta_group=opc.cta_group, transpose_c=transpose_c, BM=block_tile_shape[Int(0)], BN=block_tile_shape[Int(1)])
epilogue_dtype
comptime epilogue_dtype = TileWriter.get_epilogue_dtype()
fragment_size
comptime fragment_size = (Int(128) // _resolve_warp_size())
is_lower_frag_required
comptime is_lower_frag_required = TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].epc.is_lower_frag_required
MMA_M
comptime MMA_M = mma_shape[Int(0)]
MMA_N
comptime MMA_N = mma_shape[Int(1)]
N_dim
comptime N_dim = Int(0) if transpose_c else Int(1)
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = opc.num_stages
num_stages
comptime num_stages = TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].epc.num_stages
rep
comptime rep = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN // Int(8))
rep_frag_size
comptime rep_frag_size = ((Int(128) // _resolve_warp_size()) * TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].rep)
Stage
comptime Stage = OutputStage[opc]
stage_contiguous_size
comptime stage_contiguous_size = c_smem_dim1
stage_stride_cols
comptime stage_stride_cols = opc.stage_stride_cols
stageN
comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1
TmaOp
comptime TmaOp = TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape]
TmaOpPtr
comptime TmaOpPtr = Pointer[TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape], tma_origin]
Methods
__init__
def __init__(c_tma_op: Pointer[TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape], tma_origin]) -> Self
Initialize with pointer to TMA descriptor.
get_epilogue_dtype
write
def write(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write accumulated results to global memory (2D coords).
write_batched
def write_batched(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32, UInt32], shape: Tuple[UInt32, UInt32], alpha: Float32 = 1)
Write accumulated results to global memory (3D batched coords).
Args:
- c_tiles (SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]): TileTensor-based SMEM tile array for C output.
- stage (OutputStage[opc]): OutputStage with pipeline, index, and TMEM handle.
- tile_coord (Tuple[UInt32, UInt32, UInt32]): (m_tile, n_tile, batch) coordinates.
- shape (Tuple[UInt32, UInt32]): (M, N) problem dimensions.
- alpha (Float32): Tensor scale factor (scalar).
write_splitk
def write_splitk[reduction_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], scheduler: TileScheduler, reduction_tensor: TileTensor[accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write with split-K reduction. Only last split writes to GMEM.
write_absolute_with_bounds_check
def write_absolute_with_bounds_check[c_tensor_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: TileTensor[c_type, c_tensor_layout, MutAnyOrigin])
Write with absolute coordinates and bounds checking.
For 1D-1D grouped kernels where M coordinate is absolute.
write_with_residual
def write_with_residual[pipeline_origin: MutOrigin, //, num_src_stages: Int](self, out_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], src_tile: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_src_stages], src_pipeline: Pointer[ProducerConsumerPipeline[num_src_stages], pipeline_origin], beta: Scalar[c_type], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write with residual: D = lambda(accum) + beta * C.
Matches the CUTLASS sm100_epilogue_tma_warpspecialized lockstep
pattern: the epilogue load warp pre-fetches one source sub-tile per
inner epilogue stage into a num_src_stages-deep SMEM pipeline; this
method drives one wait_producer / use / consumer_release / step
cycle on src_pipeline per inner stage. The buffer index is read from
the pipeline's consumer_stage() rather than computed offline, so
producer and consumer stay synchronized exactly as in CUTLASS's
consumer_wait → copy(sC) → consumer_release per epi sub-tile.
Pipeline per inner stage:
- Load accum from TMEM to registers (epilogue dtype).
- Apply elementwise_compute_lambda_fn (pre-residual fusion).
- Wait for source[k] via src_pipeline.consume(); compute D = accum + beta * C reading from the SMEM buffer at the pipeline's current stage index; release source[k] on context exit.
- Apply elementwise_lambda_fn (post-residual, owns the GMEM store) OR stage to output SMEM and TMA-store to GMEM.
Parameters:
- pipeline_origin (MutOrigin): Mutability origin of the source pipeline ref.
- num_src_stages (Int): Number of source SMEM buffers; must equal the epi-load pipeline's stage count in the kernel.
Args:
- out_tiles (SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]): Output SMEM tile array (for D output).
- stage (OutputStage[opc]): OutputStage with pipeline, index, and TMEM handle.
- src_tile (SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_src_stages]): Source C SMEM tile array (num_src_stages buffers).
- src_pipeline (Pointer[ProducerConsumerPipeline[num_src_stages], pipeline_origin]): Pointer to the source producer/consumer pipeline. One acquire/release cycle is driven per inner epilogue stage.
- beta (Scalar[c_type]): Residual scale factor.
- tile_coord (Tuple[UInt32, UInt32]): (m_tile, n_tile) coordinates.
- shape (Tuple[UInt32, UInt32]): (M, N) problem dimensions.
- elect_one_warp (Bool): Whether this warp is elected for coordination.
write_batched_with_tma_epilogue_load
def write_batched_with_tma_epilogue_load[epi_load_swizzle: TensorMapSwizzle, epilogue_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], epilogue_tile: TileTensor[c_type, epilogue_layout, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_coord: Tuple[UInt32, UInt32, UInt32], c_shape: Tuple[UInt32, UInt32])
Write accumulated results with epilogue tensor addition to global memory.
Pipeline: TMEM → Registers → (+epilogue from SMEM) → SMEM → GMEM (TMA).
write_batched_with_1d_bias
def write_batched_with_1d_bias[epilogue_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], epilogue_tile: TileTensor[c_type, epilogue_layout, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_coord: Tuple[UInt32, UInt32, UInt32], c_shape: Tuple[UInt32, UInt32])
Write accumulated results with 1D bias addition to global memory.
Pipeline: TMEM -> Registers -> (+1D bias broadcast from SMEM) -> SMEM -> GMEM (TMA).
The bias SMEM tile is 1×MMA_N loaded via cp.async (linear layout, no swizzle) and then broadcast across all M rows.
write_batched_with_tma_epilogue_load_strips
def write_batched_with_tma_epilogue_load_strips[epi_load_swizzle: TensorMapSwizzle, num_epi_stages: Int](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], mut epilogue_pipeline: ProducerConsumerPipeline[num_epi_stages], epilogue_tiles_base: UnsafePointer[Scalar[c_type], MutAnyOrigin, address_space=AddressSpace.SHARED], epilogue_tile_elems: Int, tile_coord: Tuple[UInt32, UInt32, UInt32], c_shape: Tuple[UInt32, UInt32])
Write accumulated results with BM×stageN pipelined epilogue addition.
For non-AB_swapped configs. Each epilogue pipeline stage is one BM×stageN tile. Producer sends tiles in stage-outer / col_wg-inner order; consumer mirrors that structure so each TMEM stage is fully processed (load → add epilogue → write) before advancing to the next.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!