
Mojo struct

TileWriter

struct TileWriter[tma_origin: ImmutOrigin, c_type: DType, c_rank: Int, c_tile_shape: IndexList[c_rank], c_desc_shape: IndexList[c_rank], //, a_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], opc: OutputPipelineConfig, c_swizzle: TensorMapSwizzle, transpose_c: Bool, c_smem_dim0: Int, c_smem_dim1: Int, num_output_stages: Int, num_output_warps: Int, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, batched: Bool = False, problem_n: Int = 0]

Output tile writer for SM100 matmul epilogue.

Stores a pointer to the TMA descriptor; SMEM tiles are passed per call.

Parameters are passed explicitly so that the struct works with both MatmulConfig and BlockScaledMatmulConfig.

The opc (OutputPipelineConfig) parameter must match the configuration used to construct the OutputTilePipeline that supplies OutputStage instances to the write() method.

Fields

  • c_tma_op (Self.TmaOpPtr): Pointer to the TMA descriptor for the output tensor.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_tile_layout

comptime accum_tile_layout = Layout.row_major(Self.BM, Self.stageN)

AccumTmemArray

comptime AccumTmemArray = TmemArrayType[accum_type, Self.accum_tile_layout, Self.num_stages, cta_group=Self.cta_group]

bits

comptime bits = 256

BM

comptime BM = block_tile_shape[0]

BN

comptime BN = block_tile_shape[1]

c_smem_layout

comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)

cta_group

comptime cta_group = opc.cta_group

CTileArray

comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]

data_paths

comptime data_paths = 16

epc

comptime epc = EpilogueConfig.create(MMA_M=Self.MMA_M, MMA_N=Self.MMA_N, stageN=Self.stageN, cta_group=Self.cta_group, transpose_c=transpose_c, BM=Self.BM, BN=Self.BN)

epilogue_dtype

comptime epilogue_dtype = TileWriter.get_epilogue_dtype()

fragment_size

comptime fragment_size = (128 // WARP_SIZE)

is_lower_frag_required

comptime is_lower_frag_required = Self.epc.is_lower_frag_required

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]

N_dim

comptime N_dim = 0 if transpose_c else 1

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = opc.num_stages

num_stages

comptime num_stages = Self.epc.num_stages

rep

comptime rep = (Self.stageN // 8)

rep_frag_size

comptime rep_frag_size = (Self.fragment_size * Self.rep)

Stage

comptime Stage = OutputStage[opc]

stage_contiguous_size

comptime stage_contiguous_size = c_smem_dim1

stage_stride_cols

comptime stage_stride_cols = opc.stage_stride_cols

stageN

comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1

TmaOp

comptime TmaOp = TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape]

TmaOpPtr

comptime TmaOpPtr = Pointer[TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape], tma_origin]
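As a sanity check on the arithmetic of the members above, here is a minimal Python sketch (not the Mojo source) of how the derived constants relate. WARP_SIZE = 32 (the NVIDIA warp width) is an assumption here, and the example shapes are illustrative only:

```python
# Sketch of the derived comptime constants; WARP_SIZE = 32 is assumed.
WARP_SIZE = 32

def derived_constants(block_tile_shape, mma_shape, c_smem_dim0, c_smem_dim1, transpose_c):
    BM, BN = block_tile_shape[0], block_tile_shape[1]
    MMA_M, MMA_N = mma_shape[0], mma_shape[1]
    stageN = c_smem_dim0 if transpose_c else c_smem_dim1  # contiguous output dim of a stage
    N_dim = 0 if transpose_c else 1
    fragment_size = 128 // WARP_SIZE      # -> 4 elements per thread
    rep = stageN // 8                     # 8-column repetitions across one stage
    rep_frag_size = fragment_size * rep   # register fragment elements per stage
    return {"BM": BM, "BN": BN, "MMA_M": MMA_M, "MMA_N": MMA_N,
            "stageN": stageN, "N_dim": N_dim, "fragment_size": fragment_size,
            "rep": rep, "rep_frag_size": rep_frag_size}

# Example: 128x128x64 block tile, 64-column SMEM stage, no transpose.
c = derived_constants((128, 128, 64), (128, 128, 16), 128, 64, transpose_c=False)
# c["stageN"] == 64, c["rep"] == 8, c["rep_frag_size"] == 32
```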

Methods

__init__

__init__(c_tma_op: Pointer[TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape], tma_origin]) -> Self

Initialize with a pointer to the TMA descriptor.

get_epilogue_dtype

static get_epilogue_dtype() -> DType

Returns:

The DType used for epilogue computation.

write

write(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write accumulated results to global memory (2D coords).

write_batched

write_batched(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32, UInt32], shape: Tuple[UInt32, UInt32], alpha: Float32 = 1)

Write accumulated results to global memory (3D batched coords).

write_splitk

write_splitk[reduction_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], scheduler: TileScheduler, reduction_tensor: TileTensor[accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write with split-K reduction. Only the last split writes to GMEM.
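The split-K semantics described above can be sketched as follows. This is an illustrative Python model, not the Mojo implementation; the reduction buffer layout and the boolean return convention are assumptions:

```python
# Split-K sketch: each split accumulates a partial sum for the same output
# tile into a reduction workspace; only the last split folds in its partial
# and writes the combined result to global memory.
def write_splitk_sketch(partial, reduction_buf, split_idx, num_splits, gmem_out):
    if split_idx < num_splits - 1:
        # Earlier splits: accumulate into the workspace, no GMEM write.
        for i, v in enumerate(partial):
            reduction_buf[i] += v
        return False  # did not write to GMEM
    # Last split: combine with accumulated partials and write out.
    for i, v in enumerate(partial):
        gmem_out[i] = reduction_buf[i] + v
    return True
```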

write_absolute_with_bounds_check

write_absolute_with_bounds_check[c_tensor_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: TileTensor[c_type, c_tensor_layout, MutAnyOrigin])

Write with absolute coordinates and bounds checking.

For 1D-1D grouped kernels where the M coordinate is absolute.
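A hedged sketch of the bounds check described above: m_abs and m_end follow the signature, while the per-row loop and the output container are hypothetical illustration, not the Mojo implementation:

```python
# Rows are stored at absolute row offsets starting at m_abs; rows whose
# absolute index reaches m_end (the group's extent) are skipped.
def store_rows_with_bounds(tile, m_abs, m_end, out):
    for i, row in enumerate(tile):
        if m_abs + i < m_end:       # only in-bounds rows are written
            out[m_abs + i] = row
```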

write_with_residual

write_with_residual(self, out_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], src_tile: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], src_stage_idx: UInt32, beta: Scalar[c_type], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write with residual: D = lambda(accum) + beta * C.

This method extends the standard write() by adding a residual term loaded from the source tensor C in shared memory. The epilogue load warp prefetches C tiles into src_tile before this method is called.

Pipeline:

  1. Load accum from TMEM to registers
  2. Apply epilogue lambda (if present)
  3. Load C fragment from source SMEM
  4. Compute D = accum + beta * C
  5. Write D to output SMEM and TMA store to GMEM
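The arithmetic of step 4 can be illustrated numerically. This Python sketch assumes an identity lambda when no elementwise_compute_lambda_fn is supplied; it models only the math, not the TMEM/SMEM movement:

```python
# Residual epilogue math: D = lambda(accum) + beta * C, elementwise.
def residual_epilogue(accum, c_src, beta, compute_lambda=lambda x: x):
    return [compute_lambda(a) + beta * c for a, c in zip(accum, c_src)]

# Example: accum = [1.0, 2.0], C = [10.0, 20.0], beta = 0.5 -> [6.0, 12.0]
```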

Args: