Mojo struct
BlockwiseFP8Accumulator
struct BlockwiseFP8Accumulator[accum_type: DType, accum_layout: Layout, is_lower_required: Bool, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cluster_size: Int]
Register-based accumulator for blockwise FP8 matmul.
Manages upper and lower fragment tiles in registers for per-K accumulation. Unlike TMEM-based accumulation, this allows scaling in CUDA cores.
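Conceptually, the per-K scaled accumulation this struct implements can be sketched in plain Python (a language-neutral sketch: the names `promote_partial`, `partial`, `a_scale`, `b_scale`, and `accum` are illustrative stand-ins, not the actual Mojo API):

```python
# Sketch of blockwise FP8 per-K accumulation: each MMA partial result is
# scaled by per-block A and B scale factors, then folded into a persistent
# register accumulator. This scaling step runs on CUDA cores, which is why
# the accumulator lives in registers rather than TMEM.

def promote_partial(accum, partial, a_scale, b_scale):
    """Scale one K-block's partial result and accumulate it."""
    scale = a_scale * b_scale          # combined per-block scale factor
    for i, p in enumerate(partial):
        accum[i] += scale * p
    return accum

# Toy usage: two K iterations with different block scales.
accum = [0.0] * 4
accum = promote_partial(accum, [1.0, 2.0, 3.0, 4.0], a_scale=0.5, b_scale=2.0)
accum = promote_partial(accum, [1.0, 1.0, 1.0, 1.0], a_scale=1.0, b_scale=1.0)
# accum is now [2.0, 3.0, 4.0, 5.0]
```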
Parameters
- accum_type (DType): Accumulator data type (typically float32).
- accum_layout (Layout): 2D layout (num_stages, num_elements) for register tiles.
- is_lower_required (Bool): Whether the lower fragment is needed (based on cta_group/MMA_M).
- block_tile_shape (IndexList): Block tile dimensions (BM, BN, BK).
- mma_shape (IndexList): MMA operation dimensions (MMA_M, MMA_N, MMA_K).
- cluster_size (Int): Number of CTAs in the cluster.
Fields
- upper (BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].UpperTile):
- lower (BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].LowerTile):
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
bits
comptime bits = 256
BK
comptime BK = block_tile_shape.__getitem__[3, DType.int64, Int](2)
BM
comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)
data_paths
comptime data_paths = 16
fragment_size
comptime fragment_size = (128 // WARP_SIZE)
Fragments
comptime Fragments = TmemFragments[accum_type, BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].fragment_size, is_lower_required=is_lower_required]
LowerTile
comptime LowerTile = LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
MMA_M
comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)
num_elements
comptime num_elements = accum_layout.shape[1].value()
num_elements_per_load
comptime num_elements_per_load = 8
num_stages
comptime num_stages = accum_layout.shape[0].value()
rep_frag_size
comptime rep_frag_size = (BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].repeats * BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].fragment_size)
repeats
comptime repeats = (BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].num_elements // BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].fragment_size)
stageN
comptime stageN = (BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_required, block_tile_shape, mma_shape, cluster_size].repeats * 8)
UpperTile
comptime UpperTile = LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
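The derived comptime values above are related by simple arithmetic. A quick check, assuming WARP_SIZE = 32 (the NVIDIA warp size) and an illustrative num_elements = 32 (in the real struct this comes from accum_layout.shape[1]):

```python
WARP_SIZE = 32                            # lanes per warp on NVIDIA GPUs
num_elements = 32                         # illustrative value for this check

fragment_size = 128 // WARP_SIZE          # elements per thread per fragment
repeats = num_elements // fragment_size   # fragments needed to cover a stage
rep_frag_size = repeats * fragment_size   # elements covered per repeat group
stageN = repeats * 8                      # N-extent covered by one stage

print(fragment_size, repeats, rep_frag_size, stageN)  # 4 8 32 64
```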
Methods
__init__
__init__(out self)
Create accumulator with zero-initialized register tiles.
promote
promote[num_pipeline_stages: Int, num_accum_pipeline_stages: Int, stage_stride_cols: Int, cta_group: Int, num_input_stages: Int, b_scales_dtype: DType, b_scales_layout: Layout, a_scales_dtype: DType, a_scales_smem_layout: Layout](mut self, b_scales: LayoutTensor[b_scales_dtype, b_scales_layout, MutAnyOrigin], a_scales_tiles: SMemTileArray[a_scales_dtype, a_scales_smem_layout, num_pipeline_stages, 128], epi_stage: EpilogueKStage[num_accum_pipeline_stages, stage_stride_cols, cta_group, num_input_stages], work_tile_coord: Tuple[UInt, UInt], k_iter: Scalar[DType.uindex], problem_shape: StaticTuple[Int32, 3])
Load partial from TMEM, apply scales, accumulate into registers.
Core blockwise FP8 scaling: loads MMA partial from TMEM, reads A-scale from SMEM and B-scale from global memory, applies scaling, and accumulates into register tiles.
Called within a `with epi_ctx.per_k_stage(input_pipeline) as epi_stage:` block.
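The calling pattern can be sketched as a plain Python loop (a hedged sketch: `run_k_loop`, `partials`, `a_scales`, and `b_scales` are illustrative stand-ins for the surrounding kernel machinery, not the actual API):

```python
# Illustrative driver loop: one promote per K iteration. Each iteration
# scales the freshly produced partial result by that K-block's A and B
# scale factors and folds it into the register accumulator, mirroring
# what promote() does inside the per-K epilogue stage.
def run_k_loop(num_k_iters, partials, a_scales, b_scales):
    accum = [0.0] * len(partials[0])
    for k in range(num_k_iters):
        # stand-in for: with epi_ctx.per_k_stage(input_pipeline) as epi_stage:
        #                   acc.promote[...](b_scales, a_scales_tiles, epi_stage, ...)
        s = a_scales[k] * b_scales[k]
        accum = [a + s * p for a, p in zip(accum, partials[k])]
    return accum

result = run_k_loop(2, [[1.0, 1.0], [2.0, 2.0]], [1.0, 0.5], [1.0, 1.0])
# result is [2.0, 2.0]
```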