Mojo struct
TileWriter
@register_passable(trivial)
struct TileWriter[tma_origin: ImmutOrigin, c_type: DType, c_layout: Layout, c_desc_layout: Layout, //, a_type: DType, b_type: DType, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], c_smem_layout: Layout, num_output_stages: Int, stage_stride_cols: UInt, num_output_warps: UInt, max_tmem_cols: UInt = 512, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True]
Output tile writer for SM100 matmul epilogue.
Stores a pointer to the TMA descriptor; the SMEM tiles are passed in on each call.
Fields
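- c_tma_op (`TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].TmaOpPtr`): pointer to the TMA descriptor (`TMATensorTile[c_type, c_layout, c_desc_layout]`) used to store the output tile.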
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
accum_type
comptime accum_type = MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type
bits
comptime bits = 256
block_tile_shape
comptime block_tile_shape = config.block_tile_shape
BM
comptime BM = TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].block_tile_shape.__getitem__[3, DType.int64, Int](1)
c_swizzle
comptime c_swizzle = config.c_swizzle
cg1_num_stages
comptime cg1_num_stages = (TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].stageN)
cg2_num_stages
comptime cg2_num_stages = (TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].mma_shape.__getitem__[3, DType.int64, Int](1) // c_smem_layout.shape[TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].N_dim].value()) if (eq config.mma_shape.__getitem__[3, DType.int64, Int](0)._mlir_value, 256) else ((TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].stageN) // 2)
cta_group
comptime cta_group = config.cta_group
CTileArray
comptime CTileArray = SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128]
data_paths
comptime data_paths = 16
epilogue_dtype
comptime epilogue_dtype = c_type if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 80) else DType.float32
fragment_size
comptime fragment_size = (128 // WARP_SIZE)
is_lower_frag_required
comptime is_lower_frag_required = (TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].BM == 64) if (eq config.cta_group._mlir_value, 1) else (TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].cta_group == 1).__bool__().__invert__()
MMA_M
comptime MMA_M = TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].mma_shape.__getitem__[3, DType.int64, Int](1)
mma_shape
comptime mma_shape = config.mma_shape
N_dim
comptime N_dim = 0 if config.AB_swapped else 1
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = Int(config)
num_stages
comptime num_stages = (TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].mma_shape.__getitem__[3, DType.int64, Int](1) // c_smem_layout.shape[TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].N_dim].value()) if (eq config.mma_shape.__getitem__[3, DType.int64, Int](0)._mlir_value, 256) else ((TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].stageN) // 2) if (eq config.cta_group._mlir_value, 2) else TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].cg1_num_stages
rep
comptime rep = (TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].stageN // 8)
rep_frag_size
comptime rep_frag_size = (TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].rep * TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].fragment_size)
Stage
comptime Stage = OutputStage[TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, Int(stage_stride_cols), TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]
stage_contiguous_size
comptime stage_contiguous_size = c_smem_layout.shape[1].value()
stageN
comptime stageN = c_smem_layout.shape[TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].N_dim].value()
TmaOp
comptime TmaOp = TMATensorTile[c_type, c_layout, c_desc_layout]
TmaOpPtr
comptime TmaOpPtr = Pointer[TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].TmaOp, tma_origin]
transpose_c
comptime transpose_c = config.AB_swapped
Methods
__init__
__init__(c_tma_op: Pointer[TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].TmaOp, tma_origin]) -> Self
Initialize with a pointer to the TMA descriptor.
write
write(self, c_tiles: SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128], stage: OutputStage[TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, Int(stage_stride_cols), TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].cta_group], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write accumulated results to global memory.
write_splitk
write_splitk[reduction_layout: Layout](self, c_tiles: SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128], stage: OutputStage[TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, Int(stage_stride_cols), TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].cta_group], scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], reduction_tensor: LayoutTensor[TileWriter[a_type, b_type, transpose_b, config, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, max_tmem_cols, elementwise_compute_lambda_fn, register_based_epilogue].accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)
Write with split-K reduction; only the last split writes its result to GMEM.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!