Mojo struct
BlockScaledTileWriter
@register_passable(trivial)
struct BlockScaledTileWriter[tma_origin: ImmutOrigin, c_type: DType, c_layout: Layout, c_desc_layout: Layout, //, a_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int, num_accum_pipeline_stages: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, c_smem_layout: Layout, num_output_stages: Int, stage_stride_cols: Int, num_output_warps: Int]
Output tile writer for SM100 block-scaled matmul epilogue.
Uses TMAStoreExecutor with batched=True for 3D (M, N, Batch) TMA stores. All operations use structured building blocks from tile_writer.mojo.
Parameters are passed explicitly to work with BlockScaledMatmulConfig.
The stage_stride_cols parameter must match the value used when constructing the OutputTilePipeline that provides OutputStage instances to the write() method.
Fields
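- c_tma_op (`BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].TmaOpPtr`): Pointer to the TMA descriptor (`TMATensorTile[c_type, c_layout, c_desc_layout]`) for the C output, as supplied to `__init__`.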
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
accum_tile_layout
comptime accum_tile_layout = Layout.row_major(BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].BM, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].stageN)
AccumTmemArray
comptime AccumTmemArray = TmemArrayType[accum_type, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].accum_tile_layout, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].num_stages, cta_group=cta_group]
bits
comptime bits = 256
BM
comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)
cg1_num_stages
comptime cg1_num_stages = (BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].MMA_N // BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].stageN)
cg2_num_stages
comptime cg2_num_stages = (mma_shape.__getitem__[3, DType.int64, Int](1) // c_smem_layout.shape[BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].N_dim].value()) if (eq mma_shape.__getitem__[3, DType.int64, Int](0)._mlir_value, 256) else ((BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].MMA_N // BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].stageN) // 2)
CTileArray
comptime CTileArray = SMemTileArray[c_type, c_smem_layout, num_output_stages, 128]
data_paths
comptime data_paths = 16
epilogue_dtype
comptime epilogue_dtype = c_type if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 80) else DType.float32
fragment_size
comptime fragment_size = (128 // WARP_SIZE)
is_lower_frag_required
comptime is_lower_frag_required = (BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].BM == 64) if (eq cta_group._mlir_value, 1) else (cta_group == 1).__bool__().__invert__()
MMA_M
comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)
N_dim
comptime N_dim = 0 if transpose_c else 1
num_stages
comptime num_stages = (mma_shape.__getitem__[3, DType.int64, Int](1) // c_smem_layout.shape[BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].N_dim].value()) if (eq mma_shape.__getitem__[3, DType.int64, Int](0)._mlir_value, 256) else ((BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].MMA_N // BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].stageN) // 2) if (eq cta_group._mlir_value, 2) else BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].cg1_num_stages
rep_frag_size
comptime rep_frag_size = (BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].repeat * BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].fragment_size)
repeat
comptime repeat = (BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].stageN // 8)
SMEMWriter
comptime SMEMWriter = TMEMToSMemWriter[c_type, accum_type, c_smem_layout, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].BM, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].BN, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].MMA_M, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].MMA_N, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].stageN, cta_group, num_output_warps, c_swizzle, transpose_c]
Stage
comptime Stage = OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group]
stage_contiguous_size
comptime stage_contiguous_size = c_smem_layout.shape[1].value()
stageN
comptime stageN = c_smem_layout.shape[BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].N_dim].value()
StoreExecutor
comptime StoreExecutor = TMAStoreExecutor[c_type, c_smem_layout, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].BM, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].BN, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].MMA_M, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].MMA_N, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].stageN, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].stage_contiguous_size, cta_group, c_swizzle, transpose_c, BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].is_lower_frag_required, True]
TmaOp
comptime TmaOp = TMATensorTile[c_type, c_layout, c_desc_layout]
TmaOpPtr
comptime TmaOpPtr = Pointer[BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].TmaOp, tma_origin]
Methods
__init__
__init__(c_tma_op: Pointer[BlockScaledTileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps].TmaOp, tma_origin]) -> Self
Initializes the writer with a pointer to the TMA descriptor.
write
write(self, c_tiles: SMemTileArray[c_type, c_smem_layout, num_output_stages, 128], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], c_coord: Tuple[UInt32, UInt32, UInt32], c_shape: Tuple[UInt32, UInt32])
Write accumulated results to global memory.
Args:
- c_tiles (`SMemTileArray`): SMEM tile array for the C output, holding `num_output_stages` tiles (typically double-buffered).
- stage (`OutputStage`): Output stage carrying the pipeline, stage index, and TMEM handle.
- c_coord (`Tuple`): `(m_tile, n_tile, batch)` tile coordinates.
- c_shape (`Tuple`): `(M, N)` problem dimensions.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!