Skip to main content

Mojo struct

TileWriter

struct TileWriter[tma_origin: ImmutOrigin, c_type: DType, c_rank: Int, c_tile_shape: IndexList[c_rank], c_desc_shape: IndexList[c_rank], //, a_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], opc: OutputPipelineConfig, c_swizzle: TensorMapSwizzle, transpose_c: Bool, c_smem_dim0: Int, c_smem_dim1: Int, num_output_stages: Int, num_output_warps: Int, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, batched: Bool = False, problem_n: Int = 0]

Output tile writer for SM100 matmul epilogue.

Stores pointer to TMA descriptor. SMEM tiles passed per-call.

Parameters are passed explicitly to work with both MatmulConfig and BlockScaledMatmulConfig.

The opc (OutputPipelineConfig) parameter must match the config used when constructing the OutputTilePipeline that provides OutputStage instances to the write() method.

Fields

  • c_tma_op (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].TmaOpPtr): Pointer to the TMA descriptor for the C output tensor.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_tile_layout

comptime accum_tile_layout = Layout.row_major(TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].BM, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN)

AccumTmemArray

comptime AccumTmemArray = TmemArrayType[accum_type, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].accum_tile_layout, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].num_stages, cta_group=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].cta_group]

bits

comptime bits = 256

BM

comptime BM = block_tile_shape[0]

BN

comptime BN = block_tile_shape[1]

c_smem_layout

comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)

cta_group

comptime cta_group = opc.cta_group

CTileArray

comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]

data_paths

comptime data_paths = 16

epc

comptime epc = EpilogueConfig.create(MMA_M=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].MMA_M, MMA_N=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].MMA_N, stageN=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN, cta_group=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].cta_group, transpose_c=transpose_c, BM=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].BM, BN=TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].BN)

epilogue_dtype

comptime epilogue_dtype = DType.bfloat16 if (a_type == c_type and c_type == DType.bfloat16) else DType.float32

fragment_size

comptime fragment_size = (128 // WARP_SIZE)

is_lower_frag_required

comptime is_lower_frag_required = TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].epc.is_lower_frag_required

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]

N_dim

comptime N_dim = 0 if transpose_c else 1

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = opc.num_stages

num_stages

comptime num_stages = TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].epc.num_stages

rep

comptime rep = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].stageN // 8)

rep_frag_size

comptime rep_frag_size = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].fragment_size * TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].rep)

Stage

comptime Stage = OutputStage[opc]

stage_contiguous_size

comptime stage_contiguous_size = c_smem_dim1

stage_stride_cols

comptime stage_stride_cols = opc.stage_stride_cols

stageN

comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1

TmaOp

comptime TmaOp = TMATensorTile[c_type, c_rank, c_tile_shape, c_desc_shape]

TmaOpPtr

comptime TmaOpPtr = Pointer[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].TmaOp, tma_origin]

Methods

__init__

__init__(c_tma_op: Pointer[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, opc, c_swizzle, transpose_c, c_smem_dim0, c_smem_dim1, num_output_stages, num_output_warps, elementwise_lambda_fn, elementwise_compute_lambda_fn, register_based_epilogue, batched, problem_n].TmaOp, tma_origin]) -> Self

Initialize with pointer to TMA descriptor.

write

write(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write accumulated results to global memory (2D coords).

write_batched

write_batched(self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], tile_coord: Tuple[UInt32, UInt32, UInt32], shape: Tuple[UInt32, UInt32], alpha: Float32 = 1)

Write accumulated results to global memory (3D batched coords).

Args:

  • c_tiles (SMemTileArray2DRowMajor): TileTensor-based SMEM tile array for C output.
  • stage (OutputStage): OutputStage with pipeline, index, and TMEM handle.
  • tile_coord (Tuple): (m_tile, n_tile, batch) coordinates.
  • shape (Tuple): (M, N) problem dimensions.
  • alpha (Float32): Tensor scale factor (scalar).

write_splitk

write_splitk[reduction_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], scheduler: TileScheduler[scheduler.num_stages, scheduler.reduction_tile_shape, scheduler.cluster_shape, scheduler.rasterize_order, scheduler.block_swizzle_size, scheduler.num_split_k], reduction_tensor: TileTensor[accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write with split-K reduction. Only last split writes to GMEM.

write_absolute_with_bounds_check

write_absolute_with_bounds_check[c_tensor_layout: TensorLayout](self, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], output_stage: OutputStage[opc], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: TileTensor[c_type, c_tensor_layout, MutAnyOrigin])

Write with absolute coordinates and bounds checking.

For 1D-1D grouped kernels where M coordinate is absolute.

write_with_residual

write_with_residual(self, out_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], stage: OutputStage[opc], src_tile: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], src_stage_idx: UInt32, beta: Scalar[c_type], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write with residual: D = lambda(accum) + beta * C.

This method extends the standard write() to add a residual term loaded from source tensor C in shared memory. The epilogue load warp pre-fetches C tiles into src_tile before this method is called.

Pipeline:

  1. Load accum from TMEM to registers
  2. Apply epilogue lambda (if present)
  3. Load C fragment from source SMEM
  4. Compute D = accum + beta * C
  5. Write D to output SMEM and TMA store to GMEM

Args:

  • out_tiles (SMemTileArray2DRowMajor): Output SMEM tile array (for D output).
  • stage (OutputStage): OutputStage with pipeline, index, and TMEM handle.
  • src_tile (SMemTileArray2DRowMajor): Source C SMEM tile array (TileTensor-based, from epilogue load warp via smem.src_tiles()).
  • src_stage_idx (UInt32): Stage index into src_tile (0 or 1 for double-buffer).
  • beta (Scalar): Residual scale factor.
  • tile_coord (Tuple): (m_tile, n_tile) coordinates.
  • shape (Tuple): (M, N) problem dimensions.
  • elect_one_warp (Bool): Whether this warp is elected for coordination.

Was this page helpful?