IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

TileWriter

struct TileWriter[tma_origin: ImmutOrigin, c_type: DType, c_rank: Int, c_tile_shape: IndexList[c_rank], c_desc_shape: IndexList[c_rank], //, a_type: DType, accum_type: DType, block_tile_shape: IndexList[Int(3)], mma_shape: IndexList[Int(3)], opc: OutputPipelineConfig, c_swizzle: TensorMapSwizzle, transpose_c: Bool, c_smem_dim0: Int, c_smem_dim1: Int, num_output_stages: Int, num_output_warps: Int, elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, batched: Bool = False, problem_n: Int = Int(0)]

Output tile writer for the SM100 matmul epilogue.

Stores a pointer to the TMA descriptor; the shared-memory (SMEM) output tiles are passed in on each call.

Parameters are passed explicitly so that the writer works with both MatmulConfig and BlockScaledMatmulConfig.

The opc (OutputPipelineConfig) parameter must match the config used when constructing the OutputTilePipeline that provides OutputStage instances to the write() method.

Fields

  • c_tma_op (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].TmaOpPtr):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_tile_layout

comptime accum_tile_layout = Layout.row_major(block_tile_shape[Int(0)], TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN)

AccumTmemArray

comptime AccumTmemArray = TmemArrayType[accum_type, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].accum_tile_layout, EpilogueConfig.create(MMA_M=mma_shape[Int(0)], MMA_N=mma_shape[Int(1)], stageN=c_smem_dim0 if transpose_c else c_smem_dim1, cta_group=opc.cta_group, transpose_c=transpose_c, BM=block_tile_shape[Int(0)], BN=block_tile_shape[Int(1)]).num_stages, cta_group=opc.cta_group]

bits

comptime bits = 256

BM

comptime BM = block_tile_shape[Int(0)]

BN

comptime BN = block_tile_shape[Int(1)]

c_smem_layout

comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)

cta_group

comptime cta_group = opc.cta_group

CTileArray

comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]

data_paths

comptime data_paths = 16

epc

comptime epc = EpilogueConfig.create(MMA_M=mma_shape[Int(0)], MMA_N=mma_shape[Int(1)], stageN=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN, cta_group=opc.cta_group, transpose_c=transpose_c, BM=block_tile_shape[Int(0)], BN=block_tile_shape[Int(1)])

epilogue_dtype

comptime epilogue_dtype = TileWriter.get_epilogue_dtype()

fragment_size

comptime fragment_size = (Int(128) // _resolve_warp_size())

is_lower_frag_required

comptime is_lower_frag_required = TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].epc.is_lower_frag_required

MMA_M

comptime MMA_M = mma_shape[Int(0)]

MMA_N

comptime MMA_N = mma_shape[Int(1)]

N_dim

comptime N_dim = Int(0) if transpose_c else Int(1)

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = opc.num_stages

num_stages

comptime num_stages = TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].epc.num_stages

rep

comptime rep = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN // Int(8))

rep_frag_size

comptime rep_frag_size = ((Int(128) // _resolve_warp_size()) * TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].rep)

Stage

comptime Stage = OutputStage[opc]

stage_contiguous_size

comptime stage_contiguous_size = c_smem_dim1

stage_stride_cols

comptime stage_stride_cols = opc.stage_stride_cols

stageN

comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1

TmaOp

comptime TmaOp = TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape]

TmaOpPtr

comptime TmaOpPtr = Pointer[TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape], tma_origin]
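Several of the comptime members above are plain integer arithmetic on the struct parameters. The Python sketch below reproduces a few of them (stageN, N_dim, fragment_size, rep, rep_frag_size, and the accum_tile_layout shape) for one hypothetical configuration. `TileWriterShapes` is an illustrative helper, not part of the Mojo API. The sketch assumes a 32-lane warp wherever the Mojo code calls `_resolve_warp_size()`.

```python
from dataclasses import dataclass

# Assumption: 32-lane warps (NVIDIA). The Mojo code calls _resolve_warp_size().
WARP_SIZE = 32


@dataclass(frozen=True)
class TileWriterShapes:
    """Plain-Python model of some TileWriter comptime members (hypothetical helper)."""

    block_tile_shape: tuple  # (BM, BN, BK) assumed; only BM and BN are used here
    mma_shape: tuple  # (MMA_M, MMA_N, MMA_K) assumed
    c_smem_dim0: int
    c_smem_dim1: int
    transpose_c: bool

    @property
    def BM(self) -> int:
        return self.block_tile_shape[0]

    @property
    def BN(self) -> int:
        return self.block_tile_shape[1]

    @property
    def stageN(self) -> int:
        # The output stage width follows whichever SMEM dimension maps to N.
        return self.c_smem_dim0 if self.transpose_c else self.c_smem_dim1

    @property
    def N_dim(self) -> int:
        return 0 if self.transpose_c else 1

    @property
    def fragment_size(self) -> int:
        return 128 // WARP_SIZE

    @property
    def rep(self) -> int:
        return self.stageN // 8

    @property
    def rep_frag_size(self) -> int:
        return self.fragment_size * self.rep

    @property
    def accum_tile_shape(self) -> tuple:
        # Shape of Layout.row_major(BM, stageN).
        return (self.BM, self.stageN)


shapes = TileWriterShapes((128, 128, 64), (128, 128, 16), 128, 32, transpose_c=False)
print(shapes.stageN, shapes.rep, shapes.fragment_size, shapes.rep_frag_size)
# prints: 32 4 4 16
```

With transpose_c set, stageN comes from c_smem_dim0 instead and N_dim flips to 0. Every other member in the sketch follows from stageN.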

Methods

__init__

def __init__(c_tma_op: Pointer[TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape], tma_origin]) -> Self

Initialize with a pointer to the TMA descriptor.

get_epilogue_dtype

static def get_epilogue_dtype() -> DType

Returns:

DType

write

def write(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write accumulated results to global memory (2D coords).

write_batched

def write_batched(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32, UInt32], shape: Tuple[UInt32, UInt32], alpha: Float32 = 1)

Write accumulated results to global memory (3D batched coords).

Args:

write_splitk

def write_splitk[reduction_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], scheduler: TileScheduler, reduction_tensor: TileTensor[accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write with split-K reduction. Only the last split writes to GMEM.
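The "only the last split writes" rule can be modeled on the host as follows. Each split adds its partial tile into a shared workspace, and whichever split arrives last writes the fully reduced tile. This is a conceptual Python sketch, not the kernel's mechanism. The arrival order is a hypothetical input here; the real method coordinates through its TileScheduler, WorkInfo, and reduction_tensor, whose details this page does not describe.

```python
def splitk_reduce(partials, arrival_order):
    """Model split-K output: every split accumulates into a workspace, and the
    split that arrives last writes the reduced tile to "GMEM".

    partials: list of per-split 2D tiles (lists of lists), all the same shape.
    arrival_order: hypothetical order in which the splits finish.
    Returns the list of (split_id, tile) GMEM writes.
    """
    rows, cols = len(partials[0]), len(partials[0][0])
    workspace = [[0.0] * cols for _ in range(rows)]
    arrived = 0
    gmem_writes = []
    for split_id in arrival_order:
        for i in range(rows):
            for j in range(cols):
                workspace[i][j] += partials[split_id][i][j]
        arrived += 1
        if arrived == len(partials):  # last split: write the reduced tile
            gmem_writes.append((split_id, [row[:] for row in workspace]))
    return gmem_writes


partials = [[[1, 2], [3, 4]], [[10, 20], [30, 40]], [[100, 200], [300, 400]]]
print(splitk_reduce(partials, arrival_order=[2, 0, 1]))
# prints: [(1, [[111.0, 222.0], [333.0, 444.0]])]
```

Exactly one GMEM write happens regardless of the arrival order. The split that performs it is the one that finishes last, not necessarily split 0.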

write_absolute_with_bounds_check

def write_absolute_with_bounds_check[c_tensor_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: TileTensor[c_type, c_tensor_layout, MutAnyOrigin])

Write with absolute coordinates and bounds checking.

For 1D-1D grouped kernels, where the M coordinate is absolute.
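The Python sketch below shows what an absolute-coordinate store with an M bound looks like. The tile is placed at row m_abs and column n_abs of the output, and any row at or beyond m_end is skipped. Two behaviors are assumptions not stated on this page and should be checked against the source: that columns past the tensor width are also masked, and that expert_scale multiplies each stored value.

```python
def write_absolute_model(tile, m_abs, n_abs, m_end, expert_scale, c):
    """Store `tile` into the 2D list `c` at absolute row m_abs, column n_abs.

    Rows at or beyond m_end are skipped (the M bounds check). Assumptions for
    this sketch: columns past the width of `c` are also skipped, and each
    stored value is multiplied by expert_scale.
    """
    n_total = len(c[0])
    for i, row in enumerate(tile):
        m = m_abs + i
        if m >= m_end:
            break  # this row and all later rows fall outside [.., m_end)
        for j, value in enumerate(row):
            n = n_abs + j
            if n < n_total:
                c[m][n] = value * expert_scale
    return c


c = [[0.0] * 3 for _ in range(4)]
write_absolute_model([[1, 2], [3, 4]], m_abs=2, n_abs=1, m_end=3, expert_scale=0.5, c=c)
print(c)
# prints: [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 1.0], [0.0, 0.0, 0.0]]
```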

write_with_residual

def write_with_residual[pipeline_origin: MutOrigin, //, num_src_stages: Int](self, out_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], src_tile: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_src_stages], src_pipeline: Pointer[ProducerConsumerPipeline[num_src_stages], pipeline_origin], beta: Scalar[c_type], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write with residual: D = lambda(accum) + beta * C.

Matches the CUTLASS sm100_epilogue_tma_warpspecialized lockstep pattern: the epilogue load warp pre-fetches one source sub-tile per inner epilogue stage into a num_src_stages-deep SMEM pipeline; this method drives one wait_producer / use / consumer_release / step cycle on src_pipeline per inner stage. The buffer index is read from the pipeline's consumer_stage() rather than computed offline, so producer and consumer stay synchronized exactly as in CUTLASS's consumer_wait → copy(sC) → consumer_release per epi sub-tile.

Pipeline per inner stage:

  1. Load accum from TMEM to registers (epilogue dtype).
  2. Apply elementwise_compute_lambda_fn (pre-residual fusion).
  3. Wait for source[k] via src_pipeline.consume(); compute D = accum + beta * C reading from the SMEM buffer at the pipeline's current stage index; release source[k] on context exit.
  4. Apply elementwise_lambda_fn (post-residual, owns the GMEM store) OR stage to output SMEM and TMA-store to GMEM.
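The lockstep pattern and the per-stage steps above can be modeled in plain Python with a ring buffer of num_src_stages slots. A "producer" prefetches source sub-tiles into free slots in order. For stage k, the consumer reads the slot at k mod num_src_stages, checks that it holds source k, computes D = f(accum) + beta * C, and releases the slot. This is a sketch of the synchronization idea only. `compute_fn` stands in for elementwise_compute_lambda_fn, the appended result stands in for step 4, and none of the names are the Mojo API.

```python
def write_with_residual_model(accum_stages, src_stages, beta, num_src_stages,
                              compute_fn=lambda x: x):
    """Model D = compute_fn(accum) + beta * C, one inner stage at a time."""
    ring = [None] * num_src_stages  # SMEM source buffers; None means free
    produced = 0  # index of the next source sub-tile the load warp fetches

    def producer_fill():
        nonlocal produced
        # Epilogue load warp: prefetch into free slots, strictly in order.
        while produced < len(src_stages) and ring[produced % num_src_stages] is None:
            ring[produced % num_src_stages] = (produced, src_stages[produced])
            produced += 1

    outputs = []
    for k, accum in enumerate(accum_stages):  # step 1: accum for stage k
        producer_fill()
        slot = k % num_src_stages  # consumer stage index
        tag, c = ring[slot]  # wait: the slot must hold source sub-tile k
        assert tag == k, "producer and consumer fell out of lockstep"
        values = [compute_fn(a) for a in accum]  # step 2: pre-residual fusion
        d = [v + beta * s for v, s in zip(values, c)]  # step 3: add beta * C
        ring[slot] = None  # release the source buffer
        outputs.append(d)  # step 4: stands in for the post lambda or TMA store
    return outputs


accum = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
src = [[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]]
print(write_with_residual_model(accum, src, 0.5, 2, compute_fn=lambda x: 2 * x))
# prints: [[7.0, 9.0], [16.0, 18.0], [25.0, 27.0]]
```

Because the consumer derives the slot from the same running counter the producer used, the assertion never fires for any num_src_stages of 1 or more. This mirrors the reason the page gives for reading the index from consumer_stage().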

Parameters:

  • pipeline_origin (MutOrigin): Mutability origin of the source pipeline ref.
  • num_src_stages (Int): Number of source SMEM buffers; must equal the epi-load pipeline's stage count in the kernel.

Args:

write_batched_with_tma_epilogue_load

def write_batched_with_tma_epilogue_load[epi_load_swizzle: TensorMapSwizzle, epilogue_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], epilogue_tile: TileTensor[c_type, epilogue_layout, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_coord: Tuple[UInt32, UInt32, UInt32], c_shape: Tuple[UInt32, UInt32])

Write accumulated results with epilogue tensor addition to global memory.

Pipeline: TMEM → Registers → (+epilogue from SMEM) → SMEM → GMEM (TMA).

write_batched_with_1d_bias

def write_batched_with_1d_bias[epilogue_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], epilogue_tile: TileTensor[c_type, epilogue_layout, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_coord: Tuple[UInt32, UInt32, UInt32], c_shape: Tuple[UInt32, UInt32])

Write accumulated results with 1D bias addition to global memory.

Pipeline: TMEM -> Registers -> (+1D bias broadcast from SMEM) -> SMEM -> GMEM (TMA).

The bias SMEM tile is 1×MMA_N, loaded via cp.async (linear layout, no swizzle), and then broadcast across all M rows.
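The difference between the two epilogue-addition methods is mostly one of shape. In write_batched_with_tma_epilogue_load, the added tile matches the accumulator tile element for element. In write_batched_with_1d_bias, a single 1×MMA_N row is broadcast across every M row. The Python sketch below models only that arithmetic on nested lists. It says nothing about TMEM, cp.async, or TMA, and the function names are illustrative.

```python
def add_epilogue_tile(accum, epilogue):
    """Full-tile addition: `epilogue` has the same shape as `accum`."""
    return [[a + e for a, e in zip(row_a, row_e)] for row_a, row_e in zip(accum, epilogue)]


def add_1d_bias(accum, bias):
    """1D bias: one row of length MMA_N broadcast across all M rows of `accum`."""
    assert all(len(row) == len(bias) for row in accum), "bias length must equal row width"
    return [[a + b for a, b in zip(row, bias)] for row in accum]


print(add_1d_bias([[1, 2, 3], [4, 5, 6]], [10, 20, 30]))
# prints: [[11, 22, 33], [14, 25, 36]]
```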

write_batched_with_tma_epilogue_load_strips

def write_batched_with_tma_epilogue_load_strips[epi_load_swizzle: TensorMapSwizzle, num_epi_stages: Int](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], mut epilogue_pipeline: ProducerConsumerPipeline[num_epi_stages], epilogue_tiles_base: UnsafePointer[Scalar[c_type], MutAnyOrigin, address_space=AddressSpace.SHARED], epilogue_tile_elems: Int, tile_coord: Tuple[UInt32, UInt32, UInt32], c_shape: Tuple[UInt32, UInt32])

Write accumulated results with BM×stageN pipelined epilogue addition.

For non-AB_swapped configs. Each epilogue pipeline stage is one BM×stageN tile. The producer sends tiles in stage-outer / col_wg-inner order, and the consumer mirrors that structure so each TMEM stage is fully processed (load → add epilogue → write) before advancing to the next.
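The matching loop nests on the producer and consumer sides can be illustrated with the Python sketch below. The producer enumerates (stage, col_wg) pairs stage-outer and col_wg-inner. The consumer uses the same nesting, so the i-th tile it consumes is the i-th tile sent, and each TMEM stage completes before the next one starts. Two inputs are hypothetical and not documented on this page: the number of column groups (`num_col_wg`), and the assumption that tiles occupy a simple ring of num_epi_stages slots.

```python
def strip_schedule(num_tmem_stages, num_col_wg, num_epi_stages):
    """Return (stage, col_wg, slot) tuples in the order the consumer processes them.

    num_col_wg and the ring-slot assignment are assumptions for this sketch.
    """
    # Producer order: stage-outer, col_wg-inner.
    sent = [(s, wg) for s in range(num_tmem_stages) for wg in range(num_col_wg)]
    schedule = []
    i = 0
    for s in range(num_tmem_stages):  # consumer mirrors the stage-outer loop
        for wg in range(num_col_wg):  # ...and the col_wg-inner loop
            assert sent[i] == (s, wg), "consumer order must match producer order"
            schedule.append((s, wg, i % num_epi_stages))  # assumed ring slot
            i += 1
    return schedule


print(strip_schedule(2, 2, 3))
# prints: [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 0)]
```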