Mojo struct
GroupedBlockScaledMatmulKernel
struct GroupedBlockScaledMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], max_groups: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True]
Grouped block-scaled matmul kernel with dynamic tensormap updates.
This kernel extends BlackwellBlockScaledMatmulKernel to support grouped GEMM:
- Uses GroupedTileScheduler for linear tile iteration across groups
- Uses GroupedTensormapManager for per-block updatable TMA descriptors
- Updates tensormaps when transitioning between groups
Architecture (aligned with NVIDIA CuTe DSL grouped_blockscaled_gemm.py):
- TMA warp: Initializes A/B/SFA/SFB tensormaps, handles group transitions
- MMA warp: Waits for tensormap init, consumes tiles, performs block-scaled MMA
- Epilogue warps: Initializes C tensormap, handles C group transitions
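As a rough illustration of "linear tile iteration across groups", the Python sketch below maps a linear tile index to a (group, tile_m, tile_n) triple and flags the points where the group changes, which is where a tensormap update would be needed. The function names, the row-major tile order within a group, and the (M, N, K) problem-size tuples are assumptions made for this sketch; they are not the GroupedTileScheduler API.

```python
def ceil_div(a, b):
    """Integer ceiling division."""
    return (a + b - 1) // b


def locate_tile(linear_idx, problem_sizes, bm, bn):
    """Map a linear tile index to (group, tile_m, tile_n), or None past the end.

    problem_sizes is a hypothetical list of (M, N, K) tuples, one per group.
    Tiles inside a group are assumed to be numbered row-major.
    """
    start = 0
    for group, (m, n, _k) in enumerate(problem_sizes):
        tiles_m = ceil_div(m, bm)
        tiles_n = ceil_div(n, bn)
        count = tiles_m * tiles_n
        if linear_idx < start + count:
            local = linear_idx - start
            return group, local // tiles_n, local % tiles_n
        start += count
    return None


def iterate_tiles(problem_sizes, bm, bn, first, stride):
    """Yield (group, tile_m, tile_n, group_changed) for one worker.

    group_changed is True whenever the worker enters a group it was not
    already processing; a kernel would refresh its descriptors at that point.
    """
    prev_group = None
    idx = first
    while True:
        loc = locate_tile(idx, problem_sizes, bm, bn)
        if loc is None:
            return
        group, tile_m, tile_n = loc
        yield group, tile_m, tile_n, group != prev_group
        prev_group = group
        idx += stride


# Two groups with different shapes, 128x128 output tiles (illustrative values).
sizes = [(256, 128, 64), (128, 256, 64)]
schedule = list(iterate_tiles(sizes, 128, 128, first=0, stride=1))
```

With a single worker (`stride=1`), the group-changed flag is set only on the first tile of each group, so descriptor updates are amortized over all tiles of that group.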
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_expected_bytes
comptime a_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_smem_layout.size() * size_of[a_type]())
a_smem_layout
comptime a_smem_layout = tile_layout_k_major[a_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, config.a_swizzle]()
a_swizzle_elems
comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())
a_tile_dim1
comptime a_tile_dim1 = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N)
a_tma_load_size
comptime a_tma_load_size = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1 * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_swizzle_elems)
a_tma_rows
comptime a_tma_rows = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1
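The A-side TMA sizing members above reduce to a few integer operations. The Python sketch below mirrors those formulas so the relationships are easier to check by hand; the function name and the sample values (a 128-byte swizzle, 1-byte elements, BM = 128, CLUSTER_N = 2) are assumptions chosen for illustration, not values taken from any particular config.

```python
def a_tma_dims(bm, cluster_n, swizzle_bytes, elem_bytes):
    """Mirror the a_* comptime formulas with plain integers."""
    a_swizzle_elems = swizzle_bytes // elem_bytes  # config.a_swizzle.bytes() // size_of[a_type]()
    a_tile_dim1 = bm // cluster_n                  # BM // CLUSTER_N
    return {
        "a_swizzle_elems": a_swizzle_elems,
        "a_tile_dim1": a_tile_dim1,
        "a_tma_rows": a_tile_dim1,                 # alias of a_tile_dim1
        "a_tma_load_size": a_tile_dim1 * a_swizzle_elems,
    }


# Illustrative inputs only.
dims = a_tma_dims(bm=128, cluster_n=2, swizzle_bytes=128, elem_bytes=1)
```

In this reading, each CTA loads a BM // CLUSTER_N slice of rows of the A tile, and one TMA load covers that many rows by one swizzle width of elements.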
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS)
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = 1
accum_type
comptime accum_type = DType.float32
ADescLayout
comptime ADescLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
ATileLayout
comptime ATileLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
ATmaOp
comptime ATmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATmaTile.InnerType
ATmaTile
comptime ATmaTile = TMATile[a_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]
b_expected_bytes
comptime b_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_smem_layout.size() * size_of[b_type]())
b_smem_layout
comptime b_smem_layout = tile_layout_k_major[b_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, config.b_swizzle]()
b_swizzle_elems
comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())
b_tile_dim1
comptime b_tile_dim1 = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN // (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M // GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group))
b_tma_load_size
comptime b_tma_load_size = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1 * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_swizzle_elems)
b_tma_rows
comptime b_tma_rows = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1
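The B-side row count differs from the A side in its divisor: BN is divided by CLUSTER_M // cta_group rather than by a cluster dimension alone. A small Python sketch of that formula follows; the function name and the sample cluster shapes are hypothetical, and the code assumes CLUSTER_M is a multiple of cta_group.

```python
def b_tile_dim1(bn, cluster_m, cta_group):
    """Mirror b_tile_dim1 = BN // (CLUSTER_M // cta_group)."""
    # Assumption: cluster_m is a multiple of cta_group, so the divisor is >= 1.
    return bn // (cluster_m // cta_group)


def b_tma_load_size(bn, cluster_m, cta_group, swizzle_bytes, elem_bytes):
    """Mirror b_tma_load_size = b_tile_dim1 * b_swizzle_elems."""
    b_swizzle_elems = swizzle_bytes // elem_bytes
    return b_tile_dim1(bn, cluster_m, cta_group) * b_swizzle_elems


# With a 2-CTA group in a cluster of 2 along M, the divisor is 1.
rows_pair = b_tile_dim1(bn=256, cluster_m=2, cta_group=2)
# With a cluster of 4 along M, each load covers half of BN.
rows_quad = b_tile_dim1(bn=256, cluster_m=4, cta_group=2)
```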
BDescLayout
comptime BDescLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
BK
comptime BK = config.block_tile_shape.__getitem__[Int](2)
BM
comptime BM = config.block_tile_shape.__getitem__[Int](0)
BN
comptime BN = config.block_tile_shape.__getitem__[Int](1)
BTileLayout
comptime BTileLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
BTmaOp
comptime BTmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTmaTile.InnerType
BTmaTile
comptime BTmaTile = TMATile[b_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]
c_smem_layout
comptime c_smem_layout = Layout.row_major(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputN)
c_swizzle_elems
comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())
c_tile_dim1
comptime c_tile_dim1 = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputM if (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M == 256) if (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M == 256)._mlir_value else (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group == 1) if (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M == 256) if (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M == 256)._mlir_value else (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group == 1) else config.AB_swapped else 64
c_tile_dim2
comptime c_tile_dim2 = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_swizzle_elems if config.AB_swapped else GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputN
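The nested conditional printed for c_tile_dim1 looks like the expanded form of a short-circuit `or` (each `X if X else Y` pair reads as `X or Y`). Under that reading, which is an interpretation of the generated text and not something the page states, the two C tile dimensions could be summarized by the Python sketch below. The function name and sample values are hypothetical.

```python
def c_tile_dims(output_m, output_n, mma_m, cta_group, ab_swapped, c_swizzle_elems):
    """One possible reading of c_tile_dim1 / c_tile_dim2.

    Assumes the printed condition is equivalent to
    (MMA_M == 256) or (cta_group == 1) or config.AB_swapped.
    """
    use_output_m = (mma_m == 256) or (cta_group == 1) or ab_swapped
    dim1 = output_m if use_output_m else 64
    dim2 = c_swizzle_elems if ab_swapped else output_n
    return dim1, dim2


# Illustrative inputs only.
full_m = c_tile_dims(128, 32, mma_m=256, cta_group=2, ab_swapped=False, c_swizzle_elems=64)
half_m = c_tile_dims(128, 32, mma_m=128, cta_group=2, ab_swapped=False, c_swizzle_elems=64)
swapped = c_tile_dims(128, 32, mma_m=128, cta_group=2, ab_swapped=True, c_swizzle_elems=64)
```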
CDescLayout
comptime CDescLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim1], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
clc_consumer_arv_count
comptime clc_consumer_arv_count = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group)
clc_producer_arv_count
comptime clc_producer_arv_count = 1
clc_throttle_consumer_arv_count
comptime clc_throttle_consumer_arv_count = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SCHEDULER_THREADS
clc_throttle_producer_arv_count
comptime clc_throttle_producer_arv_count = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape.__getitem__[Int](0)
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape.__getitem__[Int](1)
CLUSTER_SIZE
comptime CLUSTER_SIZE = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N)
Context
comptime Context = KernelContext[0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N]
cta_group
comptime cta_group = config.cta_group
CTileLayout
comptime CTileLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim2], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim1].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim2].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim2].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
CTmaOp
comptime CTmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTmaTile.InnerType
CTmaTile
comptime CTmaTile = TMATile[c_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]
EPILOGUE_THREADS
comptime EPILOGUE_THREADS = (4 * WARP_SIZE)
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS]
GroupPtrLayout
comptime GroupPtrLayout = Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
GroupPtrTile
comptime GroupPtrTile = TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin]
input_expected_bytes
comptime input_expected_bytes = ((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * (((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_expected_bytes + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_expected_bytes) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfa_expected_bytes) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfb_expected_bytes)) * config)
InputTilePipelineType
comptime InputTilePipelineType = InputTilePipeline[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size]
MMA_K
comptime MMA_K = config.mma_shape.__getitem__[Int](2)
MMA_M
comptime MMA_M = config.mma_shape.__getitem__[Int](0)
MMA_N
comptime MMA_N = config.mma_shape.__getitem__[Int](1)
MMA_THREADS
comptime MMA_THREADS = WARP_SIZE
MmaCtx
comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS]
MmaEpilogueSync
comptime MmaEpilogueSync = WarpGroupBarrier[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS), 1]
MmaOp
comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_clc_pipeline_stages_2sm
comptime num_clc_pipeline_stages_2sm = 2
num_group_pipeline_stages
comptime num_group_pipeline_stages = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages // config)
num_output_stages
comptime num_output_stages = config.num_output_stages
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS
comptime NUM_THREADS = (((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SCHEDULER_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS)
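The thread and arrival-count members above are simple sums and products of per-role thread counts. The Python sketch below reproduces that arithmetic. It assumes WARP_SIZE is 32 and that TMA_LOAD_THREADS is one warp; TMA_LOAD_THREADS is referenced by these members, but its definition does not appear in the listing above, so that value is a guess made only for illustration.

```python
WARP_SIZE = 32  # assumption: NVIDIA warp width


def thread_budget(cta_group, tma_load_threads=WARP_SIZE):
    """Mirror the thread-count and arrival-count formulas with plain integers.

    tma_load_threads is hypothetical; the listing does not show its definition.
    """
    scheduler_threads = WARP_SIZE        # SCHEDULER_THREADS
    mma_threads = WARP_SIZE              # MMA_THREADS
    epilogue_threads = 4 * WARP_SIZE     # EPILOGUE_THREADS
    return {
        "NUM_THREADS": scheduler_threads + tma_load_threads + mma_threads + epilogue_threads,
        "mma_epilogue_sync_threads": mma_threads + epilogue_threads,
        "accum_pipeline_consumer_arv_count": cta_group * epilogue_threads,
        "clc_consumer_arv_count": mma_threads * cta_group,
        "clc_throttle_consumer_arv_count": scheduler_threads,
        "clc_throttle_producer_arv_count": tma_load_threads,
    }


budget = thread_budget(cta_group=2)
```

Under these assumptions the block has 224 threads: one scheduler warp, one TMA load warp, one MMA warp, and four epilogue warps.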
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
OutputM
comptime OutputM = config.output_tile_shape.__getitem__[Int](0)
OutputN
comptime OutputN = config.output_tile_shape.__getitem__[Int](1)
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]
ProblemSizesLayout
comptime ProblemSizesLayout = Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]]
ProblemSizesTile
comptime ProblemSizesTile = TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin]
SCHEDULER_THREADS
comptime SCHEDULER_THREADS = WARP_SIZE
SchedulerType
comptime SchedulerType = GroupedTileScheduler[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, max_groups]
SF_K_GROUP_SIZE
comptime SF_K_GROUP_SIZE = (4 * config)
sfa_expected_bytes
comptime sfa_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfa_smem_layout.size() * size_of[sfa_dtype]())
SFA_NUM_COLS
comptime SFA_NUM_COLS = (config * (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // 32))
sfa_smem_layout
comptime sfa_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SF_K_GROUP_SIZE * config), config.vec_sf_size]()
SFADescLayout
comptime SFADescLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
SFATileLayout
comptime SFATileLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
SFATmaOp
comptime SFATmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATmaTile.InnerType
SFATmaTile
comptime SFATmaTile = TMATile[sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]
sfb_expected_bytes
comptime sfb_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfb_smem_layout.size() * size_of[sfb_dtype]())
SFB_NUM_COLS
comptime SFB_NUM_COLS = (config * (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // 32))
sfb_smem_layout
comptime sfb_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SF_K_GROUP_SIZE * config), config.vec_sf_size]()
SFBDescLayout
comptime SFBDescLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
SFBTileLayout
comptime SFBTileLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
SFBTmaOp
comptime SFBTmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTmaTile.InnerType
SFBTmaTile
comptime SFBTmaTile = TMATile[sfb_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]
SmemType
comptime SmemType = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]
stage_stride_cols
comptime stage_stride_cols = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N
TENSORMAP_AB_INIT_BARRIER_ID
comptime TENSORMAP_AB_INIT_BARRIER_ID = 3
TENSORMAP_AB_INIT_THREADS
comptime TENSORMAP_AB_INIT_THREADS = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS)
TensormapAbInitBarrier
comptime TensormapAbInitBarrier = WarpGroupBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TENSORMAP_AB_INIT_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TENSORMAP_AB_INIT_BARRIER_ID]
TensormapManagerType
comptime TensormapManagerType = GroupedTensormapManager
TilePayload
comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM1, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM1, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_pipeline_stages]
TileWriterType
comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, config.num_accum_pipeline_stages, config.c_swizzle, config.AB_swapped, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputN, config.num_output_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, 4, elementwise_compute_lambda_fn, register_based_epilogue, True]
TMA_LOAD_THREADS
comptime TMA_LOAD_THREADS = WARP_SIZE
TMATensorTileArrayA
comptime TMATensorTileArrayA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()]
TMATensorTileArrayB
comptime TMATensorTileArrayB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()]
TMATensorTileArrayC
comptime TMATensorTileArrayC = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()]
TMATensorTileArraySFA
comptime TMATensorTileArraySFA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()]
TMATensorTileArraySFB
comptime TMATensorTileArraySFB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()]
Tmem
comptime Tmem = TmemAllocation[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]
TmemRegion
comptime TmemRegion = BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles]
Methods
validate_config
static validate_config()
Compile-time validation of kernel configuration.
run
static run(a_tma_template: TMATensorTile[a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], b_tma_template: TMATensorTile[b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], c_tma_template: TMATensorTile[c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], sfa_tma_template: TMATensorTile[sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], sfb_tma_template: TMATensorTile[sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], device_tma_b: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfa_dtype, 
_to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], device_tma_c: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], group_a_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_b_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], 
ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_c_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfa_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfb_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], problem_sizes_lt: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int)
Grouped block-scaled GEMM kernel entry point.
This kernel processes multiple GEMM problems (groups) with dynamic tensormap updates at group boundaries.
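The idea of walking tiles linearly across groups and refreshing tensormaps at group boundaries can be modeled on the host with a few lines of Python. This is a simplified sketch, not the kernel's scheduler. It assumes each row of `problem_sizes_lt` holds `(M, N, K, L)` (the layout only shows four `int32` values per group), it models a single worker visiting every tile in order, and `iterate_group_tiles` is a hypothetical helper.

```python
def iterate_group_tiles(problem_sizes, tile_m, tile_n):
    """Yield (linear_idx, group, m_tile, n_tile, group_changed) across all groups.

    group_changed is True on the first tile of each group, which is where a
    kernel of this kind would update its per-group tensormaps.
    """
    linear_idx = 0
    prev_group = None
    for group, (m, n, _k, _l) in enumerate(problem_sizes):
        tiles_m = -(-m // tile_m)  # ceiling division
        tiles_n = -(-n // tile_n)
        for m_tile in range(tiles_m):
            for n_tile in range(tiles_n):
                yield linear_idx, group, m_tile, n_tile, group != prev_group
                prev_group = group
                linear_idx += 1


# Two assumed problems with different shapes, tiled 128 x 128.
problem_sizes = [(256, 128, 64, 1), (128, 256, 64, 1)]
tiles = list(iterate_group_tiles(problem_sizes, tile_m=128, tile_n=128))
for entry in tiles:
    print(entry)
```

Each group contributes its own tile count to one shared linear index, so a worker only needs to compare the current group against the previous one to know when a descriptor update is due.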
load_input_tiles
static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], b_tma_op: TMATensorTile[b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], sfa_tma_op: TMATensorTile[sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], sfb_tma_op: TMATensorTile[sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], tiles: InputProducerStage[tiles_origin, 
GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)
Load A, B, SFA, SFB tiles using TMA with InputProducerStage.
mma
static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles], iter_idx: UInt32, k_start: UInt32)
Execute MMA operations using InputConsumerStage.
epilogue
static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], stage: OutputStage[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group], work_tile_coord: Tuple[UInt32, UInt32, UInt32], M: UInt32, N: UInt32, alpha: Float32 = 1)
Execute epilogue to store accumulated results.
run_2sm
static run_2sm(a_tma_template: TMATensorTile[a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], b_tma_template: TMATensorTile[b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], c_tma_template: TMATensorTile[c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], sfa_tma_template: TMATensorTile[sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], sfb_tma_template: TMATensorTile[sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], device_tma_b: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfa_dtype, 
_to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], device_tma_c: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], group_a_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_b_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], 
ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_c_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfa_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfb_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], problem_sizes_lt: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int)
Grouped block-scaled GEMM kernel with 2SM (cta_group=2) support.
This entry point uses CLC-based work distribution for proper 2SM synchronization between CTAs in a cluster. Both CTAs cooperate on each tile, with one CTA doing MMA work and both doing TMA loads.
The warp roles match those of block_scaled_matmul_kernel:
- Scheduler warp: Produces work items via CLC barriers
- TMA warp: Loads tiles with tensormap updates on group change
- MMA warp: Waits on CLC, executes MMA (elected CTA only)
- Epilogue warps: Stores results with tensormap updates