Mojo struct
BlackwellBlockwiseFP8MatmulKernel
struct BlackwellBlockwiseFP8MatmulKernel[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, b_scales_layout: TensorLayout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple(1)]
Blockwise FP8 matmul kernel with register-based accumulation.
This kernel implements per-K-iteration scaling in CUDA cores:
- Load warp: TMA loads A, B, A-scales to SMEM
- MMA warp: Standard MMA (partial to TMEM)
- Epilogue warp: TMEM read → scale → register accumulate → output
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_expected_bytes
comptime a_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_smem_layout.size() * size_of[a_type]())
a_scales_expected_bytes
comptime a_scales_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_scales_smem_layout.size() * size_of[a_scales_type]())
a_scales_smem_layout
comptime a_scales_smem_layout = Layout.row_major(1, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM)
a_smem_layout
comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK, config.a_swizzle]()
a_swizzle_elems
comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())
a_tile_dim0
comptime a_tile_dim0 = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM // BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_N)
a_tma_load_size
comptime a_tma_load_size = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_tile_dim0 * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_swizzle_elems)
a_tma_rows
comptime a_tma_rows = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_tile_dim0
accum_dims
comptime accum_dims = get_accumulator_dims[c_smem_dim1=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputN, block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS)
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = 1
accum_type
comptime accum_type = DType.float32
AccumTensor
comptime AccumTensor = TmemTensor[DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].tmem_accum_layout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]
Accumulator
comptime Accumulator = BlockwiseFP8Accumulator[DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].accum_dims.__getitem__[Int](0), BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].accum_dims.__getitem__[Int](1), BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].is_lower_required, config.block_tile_shape, config.mma_shape, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_SIZE]
ADescLayout
comptime ADescLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_swizzle_elems], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_swizzle_elems], ComptimeInt[1]]
AScalesLayout
comptime AScalesLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM], ComptimeInt[1]]
AScalesTmaOp
comptime AScalesTmaOp = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesTmaTile.InnerType
AScalesTmaTile
comptime AScalesTmaTile = TMATile[a_scales_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout]
ATileLayout
comptime ATileLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK], ComptimeInt[1]]
ATmaOp
comptime ATmaOp = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATmaTile.InnerType
ATmaTile
comptime ATmaTile = TMATile[a_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ADescLayout]
b_expected_bytes
comptime b_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_smem_layout.size() * size_of[b_type]())
b_smem_layout
comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK, config.b_swizzle]()
b_swizzle_elems
comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())
b_tile_dim0
comptime b_tile_dim0 = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BN // (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_M // BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group))
b_tma_load_size
comptime b_tma_load_size = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_tile_dim0 * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_swizzle_elems)
b_tma_rows
comptime b_tma_rows = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_tile_dim0
BDescLayout
comptime BDescLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_swizzle_elems], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_swizzle_elems], ComptimeInt[1]]
BK
comptime BK = config.block_tile_shape.__getitem__[Int](2)
BM
comptime BM = config.block_tile_shape.__getitem__[Int](0)
BN
comptime BN = config.block_tile_shape.__getitem__[Int](1)
BScalesTile
comptime BScalesTile = TileTensor[b_scales_type, b_scales_layout, ImmutAnyOrigin]
BTileLayout
comptime BTileLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK], ComptimeInt[1]]
BTmaOp
comptime BTmaOp = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTmaTile.InnerType
BTmaTile
comptime BTmaTile = TMATile[b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BDescLayout]
c_smem_layout
comptime c_smem_layout = Layout.row_major(BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputN)
c_swizzle_elems
comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())
c_tile_dim0
comptime c_tile_dim0 = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputM if ((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_M == 256) or (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group == 1)) else 64
CDescLayout
comptime CDescLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].c_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].c_swizzle_elems], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].c_swizzle_elems], ComptimeInt[1]]
clc_consumer_arv_count
comptime clc_consumer_arv_count = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS + (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_SIZE * ((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS)))
clc_producer_arv_count
comptime clc_producer_arv_count = 1
clc_throttle_consumer_arv_count
comptime clc_throttle_consumer_arv_count = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS
clc_throttle_producer_arv_count
comptime clc_throttle_producer_arv_count = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape.__getitem__[Int](0)
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape.__getitem__[Int](1)
CLUSTER_SIZE
comptime CLUSTER_SIZE = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_M * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_N)
Context
comptime Context = KernelContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_clc_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_N]
cta_group
comptime cta_group = config.cta_group
CTileLayout
comptime CTileLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].c_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputN], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputN], ComptimeInt[1]]
CTmaOp
comptime CTmaOp = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTmaTile.InnerType
CTmaTile
comptime CTmaTile = TMATile[c_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CDescLayout]
EPILOGUE_THREADS
comptime EPILOGUE_THREADS = (4 * WARP_SIZE)
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]
EpilogueHandle
comptime EpilogueHandle = EpilogueWarp[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]
input_expected_bytes
comptime input_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group * ((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_expected_bytes + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_expected_bytes) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_scales_expected_bytes))
InputTilePipeline
comptime InputTilePipeline = InputTilePipeline[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TilePayload, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.num_group_pipeline_stages, config.k_group_size]
is_lower_required
comptime is_lower_required = is_lower_fragment_required[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, config.block_tile_shape]()
max_tmem_cols
comptime max_tmem_cols = 512
MMA_K
comptime MMA_K = config.mma_shape.__getitem__[Int](2)
MMA_M
comptime MMA_M = config.mma_shape.__getitem__[Int](0)
MMA_N
comptime MMA_N = config.mma_shape.__getitem__[Int](1)
MMA_THREADS
comptime MMA_THREADS = WARP_SIZE
MmaCtx
comptime MmaCtx = MmaWarpContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]
MmaHandle
comptime MmaHandle = MmaWarp[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]
MmaOp
comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages
num_group_pipeline_stages
comptime num_group_pipeline_stages = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_pipeline_stages // config.k_group_size)
num_output_stages
comptime num_output_stages = config.num_output_stages
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS
comptime NUM_THREADS = (((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS)
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
OutputM
comptime OutputM = config.output_tile_shape.__getitem__[Int](0)
OutputN
comptime OutputN = config.output_tile_shape.__getitem__[Int](1)
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]
Scheduler
comptime Scheduler = TileScheduler[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape.__getitem__[Int](0), config.cluster_shape.__getitem__[Int](1), config.cluster_shape.__getitem__[Int](2)), config.raster_order, config.block_swizzle_size]
SCHEDULER_THREADS
comptime SCHEDULER_THREADS = WARP_SIZE
SmemType
comptime SmemType = BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config]
stage_stride_cols
comptime stage_stride_cols = (512 // config)
TilePayload
comptime TilePayload = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.BM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.BK, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.BK, 1, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.BM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.num_pipeline_stages]
TileWriterType
comptime TileWriterType = BlockwiseFP8TileWriter[c_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputN, DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].accum_dims.__getitem__[Int](0), BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].accum_dims.__getitem__[Int](1), block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, is_lower_frag_required=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].is_lower_required, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, num_output_stages=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_output_stages, num_output_warps=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_output_warps, c_swizzle=config.c_swizzle]
TMA_LOAD_THREADS
comptime TMA_LOAD_THREADS = WARP_SIZE
Tmem
comptime Tmem = TmemAllocation[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]
tmem_accum_layout
comptime tmem_accum_layout = Layout.row_major(BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_N)
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]
Methods
load_input_tiles
static load_input_tiles[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, a_scales_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](a_loader: TileLoader[a_tma_origin, a_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ADescLayout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group], b_loader: TileLoader[b_tma_origin, b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BDescLayout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group], a_scales_loader: ScalesLoader[a_scales_tma_origin, a_scales_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group], tiles: InputProducerStage[tiles_origin, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TilePayload, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, 
b_scales_layout, transpose_b, config, cluster_shape].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], iter_idx: Scalar[DType.uint], elect_one_cta: Bool)
Load A, B, and A-scales tiles using TMA.
Args:
- a_loader (TileLoader): TileLoader for the A matrix.
- b_loader (TileLoader): TileLoader for the B matrix.
- a_scales_loader (ScalesLoader): ScalesLoader for the A-scales.
- tiles (InputProducerStage): InputProducerStage context with encapsulated tile access.
- peer_cta_coord (Tuple): Peer CTA coordinates for multicast.
- work_tile_coord (Tuple): M/N coordinates of the current work tile.
- iter_idx (Scalar): K-iteration index.
- elect_one_cta (Bool): Whether this is the elected CTA in the cluster.
mma
static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TilePayload, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], accum_tensor: TmemTensor[DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].tmem_accum_layout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group])
Execute standard MMA operations (partial results to TMEM).
For blockwise FP8, each K iteration writes a fresh partial to TMEM. The epilogue accumulates across K in registers, not TMEM. Therefore init_c is always True (unlike standard matmul).
Args:
- tiles (InputConsumerStage): Input consumer stage with the A, B, and A-scales tiles.
- mma_op (MmaOpSM100_SS): The MMA operator.
- accum_tensor (TmemTensor): Typed TMEM tensor view for the accumulator stage.
validate_config
static validate_config()
Validate configuration constraints at compile time.
run
static run(a_tma_op: TMATensorTile[a_type, _to_legacy_layout[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout](), _to_legacy_layout[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ADescLayout]()], b_tma_op: TMATensorTile[b_type, _to_legacy_layout[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout](), _to_legacy_layout[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BDescLayout]()], c_tma_op: TMATensorTile[c_type, _to_legacy_layout[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout](), _to_legacy_layout[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CDescLayout]()], a_scales_tma_op: TMATensorTile[a_scales_type, _to_legacy_layout[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout]()], cluster_dim: StaticTuple[Int32, 3], num_iters: Scalar[DType.uint], b_scales: TileTensor[b_scales_type, b_scales_layout, ImmutAnyOrigin], problem_shape: StaticTuple[Int32, 3])
Kernel entry point for blockwise FP8 matmul.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!