Mojo struct
BlockwiseFP8Accumulator
struct BlockwiseFP8Accumulator[accum_type: DType, accum_num_stages: Int, accum_num_elements: Int, is_lower_required: Bool, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cluster_size: Int]
Register-based accumulator for blockwise FP8 matmul.
Manages upper and lower fragment tiles in registers for per-K accumulation. Unlike TMEM-based accumulation, this allows scaling in CUDA cores.
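The per-K scaling idea behind the struct can be sketched at the pseudocode level. This is illustrative only — the helper names (`load_partial_from_tmem`, `a_scale`, `b_scale`, `acc`) are hypothetical, not part of the actual implementation:

```mojo
# Illustrative pseudocode: blockwise FP8 per-K accumulation.
for k_iter in range(num_k_iters):
    # The MMA writes an unscaled partial product for this K block to TMEM.
    var partial = load_partial_from_tmem(k_iter)  # hypothetical helper

    # Scales are per K block: the A-scale is staged in shared memory,
    # the B-scale is read from global memory.
    var scale = a_scale[k_iter] * b_scale[k_iter]

    # Scaling happens in CUDA cores, then the result is accumulated
    # into float32 register tiles rather than back into TMEM.
    for i in range(num_elements):
        acc[i] += scale * partial[i]
```

Keeping the accumulator in registers is what makes the per-K scale multiply possible: a TMEM-resident accumulator would be updated directly by the MMA unit, with no opportunity to apply the scales between K steps.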
Parameters
- accum_type (DType): Accumulator data type (typically float32).
- accum_num_stages (Int): Number of accumulator pipeline stages.
- accum_num_elements (Int): Number of elements per stage.
- is_lower_required (Bool): Whether the lower fragment is needed.
- block_tile_shape (IndexList): Block tile dimensions (BM, BN, BK).
- mma_shape (IndexList): MMA operation dimensions (MMA_M, MMA_N, MMA_K).
- cluster_size (Int): Number of CTAs in the cluster.
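A parameterization might look like the following sketch. The concrete values (tile shapes, stage counts, cluster size) are hypothetical, chosen only to show how the parameters fit together; `Index` is assumed to come from `utils.index`:

```mojo
# Hypothetical parameterization (values are illustrative, not prescriptive).
alias Accum = BlockwiseFP8Accumulator[
    accum_type=DType.float32,           # accumulate in float32
    accum_num_stages=2,                 # accumulator pipeline stages
    accum_num_elements=64,              # elements per stage
    is_lower_required=True,             # allocate the lower fragment tile
    block_tile_shape=Index(128, 128, 128),  # (BM, BN, BK)
    mma_shape=Index(64, 128, 32),           # (MMA_M, MMA_N, MMA_K)
    cluster_size=2,                     # CTAs per cluster
]

# Construction zero-initializes both register tiles.
var acc = Accum()
```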
Fields
- upper (BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].RegTileType): Register tile holding the upper fragment.
- lower (BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].RegTileType): Register tile holding the lower fragment (used when is_lower_required is True).
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
AccumLayout
comptime AccumLayout = Layout[ComptimeInt[BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].num_stages], ComptimeInt[BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].num_elements], ComptimeInt[BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].num_elements], ComptimeInt[1]]
bits
comptime bits = 256
BK
comptime BK = block_tile_shape[2]
BM
comptime BM = block_tile_shape[0]
BN
comptime BN = block_tile_shape[1]
data_paths
comptime data_paths = 16
fragment_size
comptime fragment_size = (128 // WARP_SIZE)
Fragments
comptime Fragments = TmemFragments[accum_type, BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].fragment_size, is_lower_required=is_lower_required]
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
num_elements
comptime num_elements = accum_num_elements
num_elements_per_load
comptime num_elements_per_load = 8
num_stages
comptime num_stages = accum_num_stages
RegTileType
comptime RegTileType = TileTensor[accum_type, Layout[ComptimeInt[BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].num_stages], ComptimeInt[BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].num_elements], ComptimeInt[BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].num_elements], ComptimeInt[1]], MutExternalOrigin]
rep_frag_size
comptime rep_frag_size = (BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].repeats * BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].fragment_size)
repeats
comptime repeats = (BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].num_elements // BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].fragment_size)
stageN
comptime stageN = (BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_required, block_tile_shape, mma_shape, cluster_size].repeats * 8)
Methods
__init__
__init__(out self)
Create accumulator with zero-initialized register tiles.
promote
promote[num_pipeline_stages: Int, opc: OutputPipelineConfig, num_input_stages: Int, b_scales_dtype: DType, b_scales_layout: TensorLayout, a_scales_dtype: DType, a_scales_dim0: Int, a_scales_dim1: Int](mut self, b_scales: TileTensor[b_scales_dtype, b_scales_layout, ImmutAnyOrigin], a_scales_tiles: SMemTileArray2DRowMajor[a_scales_dtype, a_scales_dim0, a_scales_dim1, num_pipeline_stages], epi_stage: EpilogueKStage[opc, num_input_stages], work_tile_coord: Tuple[Int, Int], k_iter: Int, problem_shape: StaticTuple[Int32, 3])
Load partial from TMEM, apply scales, accumulate into registers.
Core blockwise FP8 scaling: loads MMA partial from TMEM, reads A-scale from SMEM and B-scale from global memory, applies scaling, and accumulates into register tiles.
Called within a `with epi_ctx.per_k_stage(input_pipeline) as epi_stage:` block.
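A call might be sketched as follows. The compile-time parameters are inferred from the argument types; the surrounding objects (`epi_ctx`, `input_pipeline`, `b_scales`, `a_scales_tiles`, `work_tile_coord`, `k_iter`, `problem_shape`) come from the enclosing kernel and are shown schematically, not as a runnable example:

```mojo
# Inside the kernel's per-K epilogue stage (schematic sketch):
with epi_ctx.per_k_stage(input_pipeline) as epi_stage:
    # Load the TMEM partial for this K block, apply the A/B scales,
    # and accumulate into the register tiles held by `acc`.
    acc.promote(
        b_scales,         # B scales, read from global memory
        a_scales_tiles,   # A scales, staged in shared memory
        epi_stage,        # current epilogue K-stage
        work_tile_coord,  # (m, n) tile coordinate of this CTA's work
        k_iter,           # current K iteration index
        problem_shape,    # (M, N, K) of the overall problem
    )
```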