Mojo struct
TileWriter
@register_passable(trivial)
struct TileWriter[tma_origin: ImmutOrigin, c_type: DType, c_layout: Layout, c_desc_layout: Layout, //, a_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int, num_accum_pipeline_stages: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, c_smem_dim0: Int, c_smem_dim1: Int, num_output_stages: Int, stage_stride_cols: Int, num_output_warps: Int, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, batched: Bool = False]
Output tile writer for SM100 matmul epilogue.
Stores pointer to TMA descriptor. SMEM tiles passed per-call.
Parameters are passed explicitly to work with both MatmulConfig and BlockScaledMatmulConfig.
The stage_stride_cols parameter must match the value used when constructing the OutputTilePipeline that provides OutputStage instances to the write() method.
Fields
- c_tma_op (Self.TmaOpPtr): Pointer to the TMA tensor tile descriptor for the C output tensor.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
accum_tile_layout
comptime accum_tile_layout = Layout.row_major(TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].BM, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN)
AccumTmemArray
comptime AccumTmemArray = TmemArrayType[accum_type, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].accum_tile_layout, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].num_stages, cta_group=cta_group]
bits
comptime bits = 256
BM
comptime BM = block_tile_shape.__getitem__[Int](0)
BN
comptime BN = block_tile_shape.__getitem__[Int](1)
c_smem_layout
comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)
cg1_num_stages
comptime cg1_num_stages = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].MMA_N // TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN)
cg2_num_stages
comptime cg2_num_stages = (mma_shape.__getitem__[Int](1) // c_smem_dim0 if transpose_c else c_smem_dim1) if (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].MMA_M == 256)._mlir_value else ((TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].MMA_N // TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN) // 2)
CTileArray
comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]
data_paths
comptime data_paths = 16
epilogue_dtype
comptime epilogue_dtype = c_type if (a_type == DType.bfloat16)._mlir_value else DType.float32
fragment_size
comptime fragment_size = (128 // WARP_SIZE)
is_lower_frag_required
comptime is_lower_frag_required = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].BM == 64) if (cta_group == 1)._mlir_value else (cta_group == 1).__bool__().__invert__()
MMA_M
comptime MMA_M = mma_shape.__getitem__[Int](0)
MMA_N
comptime MMA_N = mma_shape.__getitem__[Int](1)
N_dim
comptime N_dim = 0 if transpose_c else 1
num_stages
comptime num_stages = (mma_shape.__getitem__[Int](1) // c_smem_dim0 if transpose_c else c_smem_dim1) if (mma_shape.__getitem__[Int](0) == 256)._mlir_value else ((TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].MMA_N // TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN) // 2) if (cta_group == 2)._mlir_value else TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].cg1_num_stages
rep
comptime rep = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].stageN // 8)
rep_frag_size
comptime rep_frag_size = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].rep * TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].fragment_size)
Stage
comptime Stage = OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group]
stage_contiguous_size
comptime stage_contiguous_size = c_smem_dim1
stageN
comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1
TmaOp
comptime TmaOp = TMATensorTile[c_type, c_layout, c_desc_layout]
TmaOpPtr
comptime TmaOpPtr = Pointer[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].TmaOp, tma_origin]
Methods
__init__
__init__(c_tma_op: Pointer[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue, batched].TmaOp, tma_origin]) -> Self
Initialize with pointer to TMA descriptor.
write
write(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write accumulated results to global memory (2D coords).
write_batched
write_batched(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], tile_coord: Tuple[UInt32, UInt32, UInt32], shape: Tuple[UInt32, UInt32], alpha: Float32 = 1)
Write accumulated results to global memory (3D batched coords).
Args:
- c_tiles (SMemTileArray2DRowMajor): TileTensor-based SMEM tile array for C output.
- stage (OutputStage): OutputStage with pipeline, index, and TMEM handle.
- tile_coord (Tuple): (m_tile, n_tile, batch) coordinates.
- shape (Tuple): (M, N) problem dimensions.
- alpha (Float32): Tensor scale factor (scalar).
write_splitk
write_splitk[reduction_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], reduction_tensor: TileTensor[accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write with split-K reduction. Only last split writes to GMEM.
write_absolute_with_bounds_check
write_absolute_with_bounds_check[c_tensor_layout: Layout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin])
Write with absolute coordinates and bounds checking.
For 1D-1D grouped kernels where M coordinate is absolute.
write_with_residual
write_with_residual(self, out_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], src_tile: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], src_stage_idx: UInt32, beta: Scalar[c_type], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write with residual: D = lambda(accum) + beta * C.
This method extends the standard write() to add a residual term loaded from source tensor C in shared memory. The epilogue load warp pre-fetches C tiles into src_tile before this method is called.
Pipeline:
- Load accum from TMEM to registers
- Apply epilogue lambda (if present)
- Load C fragment from source SMEM
- Compute D = accum + beta * C
- Write D to output SMEM and TMA store to GMEM
Args:
- out_tiles (SMemTileArray2DRowMajor): Output SMEM tile array (for D output).
- stage (OutputStage): OutputStage with pipeline, index, and TMEM handle.
- src_tile (SMemTileArray2DRowMajor): Source C SMEM tile array (TileTensor-based, from epilogue load warp via smem.src_tiles()).
- src_stage_idx (UInt32): Stage index into src_tile (0 or 1 for double-buffer).
- beta (Scalar): Residual scale factor.
- tile_coord (Tuple): (m_tile, n_tile) coordinates.
- shape (Tuple): (M, N) problem dimensions.
- elect_one_warp (Bool): Whether this warp is elected for coordination.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!