Mojo struct
TileWriter
struct TileWriter[
    tma_origin: ImmutOrigin,
    c_type: DType,
    c_rank: Int,
    c_tile_shape: IndexList[c_rank],
    c_desc_shape: IndexList[c_rank], //,
    a_type: DType,
    accum_type: DType,
    block_tile_shape: IndexList[3],
    mma_shape: IndexList[3],
    opc: OutputPipelineConfig,
    c_swizzle: TensorMapSwizzle,
    transpose_c: Bool,
    c_smem_dim0: Int,
    c_smem_dim1: Int,
    num_output_stages: Int,
    num_output_warps: Int,
    elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None,
    elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None,
    register_based_epilogue: Bool = True,
    batched: Bool = False,
    problem_n: Int = 0,
]
Output tile writer for SM100 matmul epilogue.
Stores a pointer to the TMA descriptor; SMEM tiles are passed per call.
Parameters are passed explicitly to work with both MatmulConfig and BlockScaledMatmulConfig.
The opc (OutputPipelineConfig) parameter must match the config used when constructing the OutputTilePipeline that provides OutputStage instances to the write() method.
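As a hedged sketch of this constraint (the config value, pipeline constructor, and parameter lists shown are assumptions, not the library's exact API), the same `opc` must thread through both the pipeline and the writer:

```mojo
# Illustrative wiring only; `my_opc` and the constructor arguments
# shown here are placeholders.
comptime my_opc = ...  # some OutputPipelineConfig

# The pipeline that hands out OutputStage[my_opc] instances...
var pipeline = OutputTilePipeline[my_opc](...)

# ...must share `my_opc` with the writer, or the stage arguments
# passed to write() will not type-check.
var writer = TileWriter[..., opc=my_opc, ...](Pointer(to=c_tma_op))
```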
Fields
- c_tma_op (Self.TmaOpPtr): Pointer to the TMA descriptor for the C output tensor.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
accum_tile_layout
comptime accum_tile_layout = Layout.row_major(Self.BM, Self.stageN)
AccumTmemArray
comptime AccumTmemArray = TmemArrayType[accum_type, Self.accum_tile_layout, Self.num_stages, cta_group=Self.cta_group]
bits
comptime bits = 256
BM
comptime BM = block_tile_shape[0]
BN
comptime BN = block_tile_shape[1]
c_smem_layout
comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)
cta_group
comptime cta_group = opc.cta_group
CTileArray
comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]
data_paths
comptime data_paths = 16
epc
comptime epc = EpilogueConfig.create(MMA_M=Self.MMA_M, MMA_N=Self.MMA_N, stageN=Self.stageN, cta_group=Self.cta_group, transpose_c=transpose_c, BM=Self.BM, BN=Self.BN)
epilogue_dtype
comptime epilogue_dtype = DType.bfloat16 if (a_type == c_type and c_type == DType.bfloat16) else DType.float32
fragment_size
comptime fragment_size = (128 // WARP_SIZE)
is_lower_frag_required
comptime is_lower_frag_required = Self.epc.is_lower_frag_required
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
N_dim
comptime N_dim = 0 if transpose_c else 1
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = opc.num_stages
num_stages
comptime num_stages = Self.epc.num_stages
rep
comptime rep = (Self.stageN // 8)
rep_frag_size
comptime rep_frag_size = (Self.fragment_size * Self.rep)
Stage
comptime Stage = OutputStage[opc]
stage_contiguous_size
comptime stage_contiguous_size = c_smem_dim1
stage_stride_cols
comptime stage_stride_cols = opc.stage_stride_cols
stageN
comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1
TmaOp
comptime TmaOp = TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape]
TmaOpPtr
comptime TmaOpPtr = Pointer[Self.TmaOp, tma_origin]
Methods
__init__
__init__(c_tma_op: Pointer[Self.TmaOp, tma_origin]) -> Self
Initialize with pointer to TMA descriptor.
write
write(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write accumulated results to global memory (2D coords).
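A minimal calling sketch (the coordinate values and the surrounding pipeline API are assumptions):

```mojo
# Hypothetical epilogue-loop call; `stage` comes from an
# OutputTilePipeline built with the same `opc` as the writer.
writer.write(
    c_tiles,                           # SMEM staging tiles for C
    stage,                             # OutputStage[opc]: pipeline, index, TMEM handle
    (UInt32(m_tile), UInt32(n_tile)),  # which output tile of C
    (UInt32(M), UInt32(N)),            # full (M, N) problem shape
    elect_one_warp,                    # one warp coordinates the TMA store
)
```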
write_batched
write_batched(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32, UInt32], shape: Tuple[UInt32, UInt32], alpha: Float32 = 1)
Write accumulated results to global memory (3D batched coords).
Args:
- c_tiles (SMemTileArray2DRowMajor): TileTensor-based SMEM tile array for C output.
- stage (OutputStage): OutputStage with pipeline, index, and TMEM handle.
- tile_coord (Tuple): (m_tile, n_tile, batch) coordinates.
- shape (Tuple): (M, N) problem dimensions.
- alpha (Float32): Tensor scale factor (scalar).
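A hedged calling sketch for a batched problem (the coordinate values are placeholders):

```mojo
# Hypothetical call; the third coordinate selects the batch.
writer.write_batched(
    c_tiles,
    stage,
    (UInt32(m_tile), UInt32(n_tile), UInt32(batch)),  # 3D: tile plus batch index
    (UInt32(M), UInt32(N)),
    alpha=2.0,  # scales the accumulator before the store (defaults to 1)
)
```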
write_splitk
write_splitk[reduction_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], scheduler: TileScheduler[scheduler.num_stages, scheduler.reduction_tile_shape, scheduler.cluster_shape, scheduler.rasterize_order, scheduler.block_swizzle_size, scheduler.num_split_k], reduction_tensor: TileTensor[accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write with split-K reduction. Only last split writes to GMEM.
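A hedged sketch of the split-K path (the scheduler and work_info plumbing is assumed to come from the surrounding kernel):

```mojo
# Hypothetical call; `reduction_layout` and the runtime values are placeholders.
writer.write_splitk[reduction_layout](
    c_tiles,
    stage,
    scheduler,         # TileScheduler that assigned the K splits
    reduction_tensor,  # partial sums in accum_type, shared across splits
    work_info,         # identifies this CTA's split
    (UInt32(M), UInt32(N)),
    elect_one_warp,
)
# Non-final splits only accumulate into `reduction_tensor`; the last
# split converts to c_type and issues the TMA store to GMEM.
```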
write_absolute_with_bounds_check
write_absolute_with_bounds_check[c_tensor_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: TileTensor[c_type, c_tensor_layout, MutAnyOrigin])
Write with absolute coordinates and bounds checking.
For 1D-1D grouped kernels where M coordinate is absolute.
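A hedged sketch of a grouped (1D-1D) call, where the M coordinate is absolute rather than tile-relative (all variable names here are placeholders):

```mojo
# Hypothetical call for one expert's segment of the concatenated C tensor.
writer.write_absolute_with_bounds_check[c_layout](
    c_tiles,
    output_stage,
    m_abs,         # absolute row of this tile in the grouped C tensor
    n_abs,         # absolute column
    m_end,         # exclusive row bound for this expert's group
    expert_scale,  # per-expert scale applied before the store
    c_tensor,      # destination TileTensor
)
# Rows at or beyond `m_end` are bounds-checked, so ragged group tails
# do not write past their segment.
```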
write_with_residual
write_with_residual(self, out_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], src_tile: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], src_stage_idx: UInt32, beta: Scalar[c_type], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write with residual: D = lambda(accum) + beta * C.
This method extends the standard write() to add a residual term loaded from source tensor C in shared memory. The epilogue load warp pre-fetches C tiles into src_tile before this method is called.
Pipeline:
- Load accum from TMEM to registers
- Apply epilogue lambda (if present)
- Load C fragment from source SMEM
- Compute D = accum + beta * C
- Write D to output SMEM and TMA store to GMEM
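A per-element view of the middle of this pipeline (an illustration of the math only, not the fragment-level implementation; the `epilogue` name is a stand-in for the optional lambda):

```mojo
# Sketch of D = lambda(accum) + beta * C for one element.
var acc: Scalar[accum_type] = ...   # loaded from TMEM
var c_val: Scalar[c_type] = ...     # loaded from source SMEM
var d = epilogue(acc) + beta.cast[accum_type]() * c_val.cast[accum_type]()
# `d` is then converted to c_type, staged in SMEM, and TMA-stored to GMEM.
```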
Args:
- out_tiles (SMemTileArray2DRowMajor): Output SMEM tile array (for D output).
- stage (OutputStage): OutputStage with pipeline, index, and TMEM handle.
- src_tile (SMemTileArray2DRowMajor): Source C SMEM tile array (TileTensor-based, from the epilogue load warp via smem.src_tiles()).
- src_stage_idx (UInt32): Stage index into src_tile (0 or 1 for double-buffering).
- beta (Scalar): Residual scale factor.
- tile_coord (Tuple): (m_tile, n_tile) coordinates.
- shape (Tuple): (M, N) problem dimensions.
- elect_one_warp (Bool): Whether this warp is elected for coordination.