Mojo struct
BlockwiseFP8_1D2DMatmulKernel
struct BlockwiseFP8_1D2DMatmulKernel[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, b_scales_layout: TensorLayout, c_device_layout: TensorLayout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], static_N: Int, static_K: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(1)]
Blockwise FP8 1D2D matmul kernel with register-based accumulation.
Combines blockwise FP8 scaling (per-K-block scale factors applied in CUDA cores) with 1D-1D offset-based work distribution for grouped GEMM in MoE layers.
Uses three-warp specialization (Load, MMA, Epilogue) with grid-constant TMA descriptors. Work is distributed via GroupedWorkIterator1D1D using offset-based addressing.
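To make the "blockwise scaling in CUDA cores" idea concrete, here is a hedged Python sketch (function and parameter names are illustrative, not this kernel's API): each K-block's quantized partial product is dequantized with its per-block scale factors and folded into the accumulator, mirroring how the epilogue scales partials in registers.

```python
# Illustrative sketch of blockwise-scaled GEMM accumulation (hypothetical
# names). a_q is scaled per (row, k_block); b_scales is simplified to one
# scale per k_block.
def blockwise_scaled_matmul(a_q, b_q, a_scales, b_scales, block_k):
    M, K = len(a_q), len(a_q[0])
    N = len(b_q[0])
    c = [[0.0] * N for _ in range(M)]
    for kb in range(0, K, block_k):          # one K-block at a time
        blk = kb // block_k
        for i in range(M):
            for j in range(N):
                # Partial product for this K-block (what the MMA produces).
                partial = sum(a_q[i][k] * b_q[k][j]
                              for k in range(kb, min(kb + block_k, K)))
                # Per-block dequant scale applied outside the MMA unit.
                c[i][j] += partial * a_scales[i][blk] * b_scales[blk]
    return c
```

With all scales equal to 1 this reduces to a plain matmul; nonunit scales show where the dequantization enters the accumulation.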
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
a_expected_bytes
comptime a_expected_bytes = (Self.SmemType.Core.a_smem_layout.size() * size_of[a_type]())
a_scales_expected_bytes
comptime a_scales_expected_bytes = (Self.SmemType.Core.a_scales_smem_layout.size() * size_of[a_scales_type]())
a_swizzle_elems
comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())
a_tile_dim0
comptime a_tile_dim0 = compute_tma_tile_dims[Self.BM, Self.BN, Self.MMA_M, Self.OutputM, Self.CLUSTER_M, Self.CLUSTER_N, Self.cta_group]()[0]
a_tma_load_size
comptime a_tma_load_size = (Self.a_tile_dim0 * Self.a_swizzle_elems)
a_tma_rows
comptime a_tma_rows = Self.a_tile_dim0
accum_dims
comptime accum_dims = get_accumulator_dims[c_smem_dim1=Self.OutputN, block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, cta_group=Self.cta_group]()
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[128, Self.cta_group]()[1]
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[128, Self.cta_group]()[0]
accum_type
comptime accum_type = DType.float32
Accumulator
comptime Accumulator = BlockwiseFP8Accumulator[DType.float32, Self.accum_dims[0], Self.accum_dims[1], Self.is_lower_required, config.block_tile_shape, config.mma_shape, Self.CLUSTER_SIZE]
ADescLayout
comptime ADescLayout = Layout[ComptimeInt[Self.a_tile_dim0], ComptimeInt[Self.a_swizzle_elems], ComptimeInt[Self.a_swizzle_elems], ComptimeInt[1]]
AScalesLayout
comptime AScalesLayout = Layout[ComptimeInt[1], ComptimeInt[Self.BM], ComptimeInt[Self.BM], ComptimeInt[1]]
AScalesTmaOp
comptime AScalesTmaOp = TMATensorTile[a_scales_type, Self.AScalesLayout.rank, _to_index_list[Self.AScalesLayout](), _to_index_list[Self.AScalesLayout.rank, Self.AScalesLayout]()]
ATileLayout
comptime ATileLayout = Layout[ComptimeInt[Self.a_tile_dim0], ComptimeInt[Self.BK], ComptimeInt[Self.BK], ComptimeInt[1]]
ATmaOp
comptime ATmaOp = TMATensorTile[a_type, Self.ATileLayout.rank, _to_index_list[Self.ATileLayout](), _to_index_list[Self.ATileLayout.rank, Self.ADescLayout]()]
b_expected_bytes
comptime b_expected_bytes = (Self.SmemType.Core.b_smem_layout.size() * size_of[b_type]())
b_swizzle_elems
comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())
b_tile_dim0
comptime b_tile_dim0 = compute_tma_tile_dims[Self.BM, Self.BN, Self.MMA_M, Self.OutputM, Self.CLUSTER_M, Self.CLUSTER_N, Self.cta_group]()[1]
b_tma_load_size
comptime b_tma_load_size = (Self.b_tile_dim0 * Self.b_swizzle_elems)
b_tma_rows
comptime b_tma_rows = Self.b_tile_dim0
BDescLayout
comptime BDescLayout = Layout[ComptimeInt[Self.b_tile_dim0], ComptimeInt[Self.b_swizzle_elems], ComptimeInt[Self.b_swizzle_elems], ComptimeInt[1]]
BK
comptime BK = config.block_tile_shape[2]
BM
comptime BM = config.block_tile_shape[0]
BN
comptime BN = config.block_tile_shape[1]
BScalesTile
comptime BScalesTile = TileTensor[b_scales_type, b_scales_layout, MutAnyOrigin]
BTileLayout
comptime BTileLayout = Layout[ComptimeInt[Self.b_tile_dim0], ComptimeInt[Self.BK], ComptimeInt[Self.BK], ComptimeInt[1]]
BTmaOp
comptime BTmaOp = TMATensorTile[b_type, Self.BTileLayout.rank, _to_index_list[Self.BTileLayout](), _to_index_list[Self.BTileLayout.rank, Self.BDescLayout]()]
CDeviceTile
comptime CDeviceTile = TileTensor[c_type, c_device_layout, MutAnyOrigin]
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape[0]
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape[1]
CLUSTER_SIZE
comptime CLUSTER_SIZE = (Self.CLUSTER_M * Self.CLUSTER_N)
cta_group
comptime cta_group = config.cta_group
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[Self.opc, 32, 128]
input_expected_bytes
comptime input_expected_bytes = (Self.cta_group * (Self.a_expected_bytes + Self.b_expected_bytes + Self.a_scales_expected_bytes))
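The expected-bytes members above all follow the same pattern: a shared-memory layout's element count times the element size, summed across the A tile, B tile, and A-scales tile, then multiplied by the CTA-group size to arm the TMA arrival barrier. A hedged Python sketch of that arithmetic (the example element counts are illustrative, not taken from a real config):

```python
# Hedged sketch of the expected-bytes arithmetic used to arm the TMA
# arrival barrier; real element counts come from the smem layouts.
def expected_input_bytes(a_elems, b_elems, scale_elems,
                         elem_bytes, scale_bytes, cta_group):
    a = a_elems * elem_bytes        # a_expected_bytes
    b = b_elems * elem_bytes        # b_expected_bytes
    s = scale_elems * scale_bytes   # a_scales_expected_bytes
    return cta_group * (a + b + s)  # input_expected_bytes

# e.g. a 128x128 FP8 A tile, 128x128 FP8 B tile, 128 fp32 A-scales,
# cta_group = 1: 16384 + 16384 + 512 = 33280 bytes per stage.
```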
InputTilePipelineType
comptime InputTilePipelineType = InputTilePipeline[Self.TilePayload, Self.SmemType.Core.num_group_pipeline_stages, config.k_group_size]
is_lower_required
comptime is_lower_required = is_lower_fragment_required[Self.cta_group, config.block_tile_shape]()
MMA_K
comptime MMA_K = config.mma_shape[2]
MMA_M
comptime MMA_M = config.mma_shape[0]
MMA_N
comptime MMA_N = config.mma_shape[1]
MmaCtx
comptime MmaCtx = MmaWarpContext[Self.opc, 32, 128]
MmaEpilogueSync
comptime MmaEpilogueSync = WarpGroupBarrier[160, 1]
MmaOp
comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, cta_group=Self.cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_group_pipeline_stages
comptime num_group_pipeline_stages = (Self.num_pipeline_stages // config.k_group_size)
num_output_stages
comptime num_output_stages = config.num_output_stages
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS
comptime NUM_THREADS = WarpRole1D1D.TOTAL_THREADS
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
opc
comptime opc = OutputPipelineConfig(Self.num_accum_pipeline_stages, Self.stage_stride_cols, Self.cta_group)
OutputM
comptime OutputM = config.output_tile_shape[0]
OutputN
comptime OutputN = config.output_tile_shape[1]
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[Self.opc]
SmemType
comptime SmemType = BlockwiseFP8_1D2DSmem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config]
stage_stride_cols
comptime stage_stride_cols = Self.MMA_N
TilePayload
comptime TilePayload = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, IndexList(Self.SmemType.Core.BM, Self.SmemType.Core.BK, Tuple()), IndexList(Self.SmemType.Core.BN, Self.SmemType.Core.BK, Tuple()), IndexList(1, Self.SmemType.Core.BM, Tuple()), Self.SmemType.Core.num_pipeline_stages]
TileWriterType
comptime TileWriterType = BlockwiseFP8TileWriter[c_type, Self.OutputM, Self.OutputN, DType.float32, Self.accum_dims[0], Self.accum_dims[1], block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, is_lower_frag_required=Self.is_lower_required, cta_group=Self.cta_group, num_output_stages=Self.num_output_stages, num_output_warps=Self.num_output_warps, c_swizzle=config.c_swizzle]
Tmem
comptime Tmem = TmemAllocation[Self.opc.cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[Self.opc.cta_group]
WorkIterator
comptime WorkIterator = GroupedWorkIterator1D1D[static_N, config.block_tile_shape, config.cluster_shape, Self.cta_group]
Methods
validate_config
static validate_config()
Compile-time validation of kernel configuration.
init_barriers
static init_barriers(elect_one_warp: Bool, elect_one_thread: Bool, a_tma_op: Self.ATmaOp, b_tma_op: Self.BTmaOp, a_scales_tma_op: Self.AScalesTmaOp, input_barriers: SMemArray[SharedMemBarrier, (Self.SmemType.Core.num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (Self.SmemType.Core.num_accum_pipeline_stages * 2)], tmem_dealloc: SMemArray[SharedMemBarrier, 1])
Initialize barriers and prefetch TMA descriptors.
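The `num_stages * 2` barrier counts in the signature above reflect the usual full/empty pairing: each pipeline stage owns one barrier the producer signals when data is ready and one the consumer signals when the slot can be reused. A hedged Python sketch of that handshake, with a bounded queue standing in for the shared-memory barriers:

```python
import queue
import threading

# Hedged sketch: one full/empty barrier pair per pipeline stage, modeled
# here as a bounded FIFO. put() blocks like waiting on an "empty" barrier;
# get() blocks like waiting on a "full" barrier.
def run_pipeline(num_stages, items):
    slots = queue.Queue(maxsize=num_stages)  # capacity = pipeline depth
    out = []

    def producer():
        for item in items:
            slots.put(item)                  # blocks while all stages full

    def consumer():
        for _ in items:
            out.append(slots.get())          # blocks until a stage is full

    t = threading.Thread(target=producer)
    t.start()
    consumer()
    t.join()
    return out
```

The real kernel uses asynchronous mbarriers rather than blocking queues, but the producer/consumer contract per stage is the same.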
run
static run(a_tma_op: Self.ATmaOp, b_tma_op: Self.BTmaOp, a_scales_tma_op: Self.AScalesTmaOp, b_scales: Self.BScalesTile, a_offsets: TileTensor[DType.uint32, GMEMLayout1D, MutAnyOrigin], expert_ids: TileTensor[DType.int32, GMEMLayout1D, MutAnyOrigin], expert_scales: TileTensor[DType.float32, GMEMLayout1D, MutAnyOrigin], c_device: Self.CDeviceTile, num_active_experts: Int, K: UInt32)
Grouped 1D-1D blockwise FP8 GEMM kernel entry point.
Uses grid-constant TMAs with offset-based addressing for 1D-1D layout. Accumulates in registers with per-K scaling in CUDA cores.
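The offset-based addressing can be pictured as follows: `a_offsets` marks where each expert's rows begin in the packed A matrix, and `expert_ids` selects which expert's B weights each row range uses. A hedged Python sketch of such a work iterator (names and the exact tiling order are illustrative, not this kernel's implementation):

```python
# Hedged sketch of offset-based grouped-GEMM work distribution:
# a_offsets[g]..a_offsets[g+1] is expert g's row range in the packed A
# matrix, so work tiles are addressed by (expert_id, row_offset) rather
# than by a per-group 2D grid.
def iter_work_tiles(a_offsets, expert_ids, BM):
    for g in range(len(expert_ids)):
        start, end = a_offsets[g], a_offsets[g + 1]
        for m in range(start, end, BM):   # one BM-row tile per step
            yield expert_ids[g], m        # (which B weights, which A rows)
```

Each yielded pair identifies one output tile: the expert id picks the weight slice and the row offset picks the A rows, which is what lets a single 1D iteration space cover every group.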
load_input_tiles
static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: Self.ATmaOp, b_tma_op: Self.BTmaOp, a_scales_tma_op: Self.AScalesTmaOp, tiles: ProducerTiles[tiles_origin, Self.TilePayload, Self.SmemType.Core.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_ctx: GroupedWorkContext1D1D, iter_idx: Int, elect_one_cta: Bool)
Load A, B, and A-scales tiles using TMA.
mma
static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, Self.TilePayload, Self.SmemType.Core.num_group_pipeline_stages, config.k_group_size], mma_op: Self.MmaOp, tmem_addr: UInt32)
Execute standard MMA operations (partial results to TMEM).
For blockwise FP8, each K iteration writes a fresh partial to TMEM. The epilogue accumulates across K in registers, not TMEM. Therefore init_c is always True.
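The split between the MMA stage and the epilogue can be sketched in Python (a hedged illustration, not the kernel's code): every K-block produces a fresh partial, and the final sum only ever forms in the register accumulator.

```python
# Hedged sketch of the accumulation split described above: the MMA stage
# overwrites its staging buffer with a fresh per-K-block partial
# (init_c=True every iteration), and the epilogue folds each scaled
# partial into a register-resident accumulator.
def epilogue_accumulate(partials, scales):
    acc = 0.0                      # register accumulator (never in TMEM)
    for partial, s in zip(partials, scales):
        acc += partial * s         # scaling happens in CUDA cores
    return acc
```

This is why a running sum in TMEM would be wrong here: each partial needs its own scale factor applied before it joins the total.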