Mojo struct

BlockwiseFP8Accumulator

struct BlockwiseFP8Accumulator[accum_type: DType, accum_layout: Layout, is_lower_required: Bool, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cluster_size: Int]

Register-based accumulator for blockwise FP8 matmul.

Manages upper and lower fragment tiles in registers for per-K accumulation. Unlike TMEM-based accumulation, keeping the accumulator in registers allows the per-block scales to be applied in CUDA cores.
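To make the idea concrete, here is a schematic NumPy model of per-K-block scaled accumulation (the function and argument names are illustrative, not the Mojo API): each K block's partial product is multiplied by its own A/B scale factors before being added into a register-resident accumulator, which is what TMEM-only accumulation cannot do.

```python
import numpy as np

def blockwise_scaled_matmul(a, b, a_scales, b_scales, bk):
    """Schematic model: a is (M, K), b is (K, N), with one scale per
    K block per operand. Not the Mojo implementation."""
    m, k = a.shape
    n = b.shape[1]
    accum = np.zeros((m, n), dtype=np.float32)  # register-resident accumulator
    for i, k0 in enumerate(range(0, k, bk)):
        # partial product for this K block (the MMA result)
        partial = a[:, k0:k0 + bk].astype(np.float32) @ b[k0:k0 + bk, :].astype(np.float32)
        # the scale is applied per K block, before accumulation
        accum += partial * a_scales[i] * b_scales[i]
    return accum
```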

Parameters

  • accum_type (DType): Accumulator data type (typically float32).
  • accum_layout (Layout): 2D layout (num_stages, num_elements) for register tiles.
  • is_lower_required (Bool): Whether lower fragment is needed (based on cta_group/MMA_M).
  • block_tile_shape (IndexList): Block tile dimensions (BM, BN, BK).
  • mma_shape (IndexList): MMA operation dimensions (MMA_M, MMA_N, MMA_K).
  • cluster_size (Int): Number of CTAs in the cluster.

Fields

  • upper (Self.UpperTile): Upper fragment register tile.
  • lower (Self.LowerTile): Lower fragment register tile; used only when is_lower_required is True.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

bits

comptime bits = 256

BK

comptime BK = block_tile_shape[2]

BM

comptime BM = block_tile_shape[0]

BN

comptime BN = block_tile_shape[1]

data_paths

comptime data_paths = 16

fragment_size

comptime fragment_size = (128 // WARP_SIZE)

Fragments

comptime Fragments = TmemFragments[accum_type, Self.fragment_size, is_lower_required=is_lower_required]

LowerTile

comptime LowerTile = LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]

num_elements

comptime num_elements = accum_layout.shape[1].value()

num_elements_per_load

comptime num_elements_per_load = 8

num_stages

comptime num_stages = accum_layout.shape[0].value()

rep_frag_size

comptime rep_frag_size = (Self.repeats * Self.fragment_size)

repeats

comptime repeats = (Self.num_elements // Self.fragment_size)

stageN

comptime stageN = (Self.repeats * 8)

UpperTile

comptime UpperTile = LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

Methods

__init__

__init__(out self)

Create accumulator with zero-initialized register tiles.

promote

promote[
    num_pipeline_stages: Int,
    num_accum_pipeline_stages: Int,
    stage_stride_cols: Int,
    cta_group: Int,
    num_input_stages: Int,
    b_scales_dtype: DType,
    b_scales_layout: Layout,
    a_scales_dtype: DType,
    a_scales_smem_layout: Layout,
](
    mut self,
    b_scales: LayoutTensor[b_scales_dtype, b_scales_layout, MutAnyOrigin],
    a_scales_tiles: SMemTileArray[a_scales_dtype, a_scales_smem_layout, num_pipeline_stages, 128],
    epi_stage: EpilogueKStage[num_accum_pipeline_stages, stage_stride_cols, cta_group, num_input_stages],
    work_tile_coord: Tuple[UInt, UInt],
    k_iter: Scalar[DType.uindex],
    problem_shape: StaticTuple[Int32, 3],
)

Loads a partial result from TMEM, applies scales, and accumulates into registers.

Core blockwise FP8 scaling: loads the MMA partial from TMEM, reads the A-scale from SMEM and the B-scale from global memory, applies the scaling, and accumulates into the register tiles.

Called within a with epi_ctx.per_k_stage(input_pipeline) as epi_stage: block.
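The per-stage shape of this step can be sketched as a NumPy model (names such as promote_step, tmem_partial, and b_scales_per_stage are hypothetical, not the Mojo API): the TMEM partial is a (num_stages, num_elements) fragment, and each stage is scaled by the A-scale for this K iteration and its stage's B-scale before being added to the register accumulator.

```python
import numpy as np

def promote_step(accum, tmem_partial, a_scale, b_scales_per_stage):
    """Schematic model of one promote() call.

    accum, tmem_partial: (num_stages, num_elements) register tiles;
    a_scale: scalar A-scale for this K iteration;
    b_scales_per_stage: one B-scale per stage."""
    for stage in range(accum.shape[0]):
        # scale the stage's TMEM partial, then accumulate into registers
        accum[stage] += tmem_partial[stage].astype(np.float32) * a_scale * b_scales_per_stage[stage]
    return accum
```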
