
Mojo struct

GroupedBlockScaledMatmulKernel

struct GroupedBlockScaledMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], max_groups: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(1), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True]

Grouped block-scaled matmul kernel with dynamic tensormap updates.

This kernel extends BlackwellBlockScaledMatmulKernel to support grouped GEMM:

  • Uses GroupedTileScheduler for linear tile iteration across groups
  • Uses GroupedTensormapManager for per-block updatable TMA descriptors
  • Updates tensormaps when transitioning between groups
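The linear tile iteration across groups can be pictured with a small Python model. This is a sketch under assumptions, not the GroupedTileScheduler implementation. It assumes that each group's output tiles are numbered consecutively, that groups follow one another in order, and that tiles within a group are row-major over (m, n). The names `locate_tile` and `problem_mn` are hypothetical.

```python
from typing import List, Optional, Tuple


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def locate_tile(
    linear_idx: int,
    problem_mn: List[Tuple[int, int]],  # hypothetical: (M, N) per group
    bm: int,
    bn: int,
) -> Optional[Tuple[int, int, int]]:
    """Map a global linear tile index to (group, tile_m, tile_n).

    Returns None once the index runs past the last group's tiles.
    """
    start = 0
    for group, (m, n) in enumerate(problem_mn):
        tiles_m = ceil_div(m, bm)
        tiles_n = ceil_div(n, bn)
        count = tiles_m * tiles_n
        if linear_idx < start + count:
            local = linear_idx - start
            return group, local // tiles_n, local % tiles_n
        start += count
    return None


# Two groups with BM = BN = 128: group 0 has 2x1 tiles, group 1 has 1x3 tiles.
problems = [(256, 128), (100, 300)]
for idx in range(6):
    print(idx, locate_tile(idx, problems, 128, 128))
```

In this model, the point where two consecutive indices map to different groups is where the kernel, per the description above, would switch to that group's tensormaps.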

Architecture (aligned with NVIDIA CuTe DSL grouped_blockscaled_gemm.py):

  • TMA warp: Initializes A/B/SFA/SFB tensormaps, handles group transitions
  • MMA warp: Waits for tensormap init, consumes tiles, performs block-scaled MMA
  • Epilogue warps: Initializes C tensormap, handles C group transitions
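The thread budget follows from the comptime members SCHEDULER_THREADS, TMA_LOAD_THREADS, MMA_THREADS and EPILOGUE_THREADS listed below. The Python sketch assigns a role to each thread by warp index. Two points are assumptions rather than facts stated on this page. First, TMA_LOAD_THREADS is taken to be one warp (32 threads). Second, the warp order (scheduler, TMA, MMA, then four epilogue warps) simply mirrors the order of the terms in NUM_THREADS.

```python
WARP_SIZE = 32  # NVIDIA warp width

SCHEDULER_THREADS = WARP_SIZE
TMA_LOAD_THREADS = WARP_SIZE  # assumption: a single TMA warp
MMA_THREADS = WARP_SIZE
EPILOGUE_THREADS = 4 * WARP_SIZE

NUM_THREADS = SCHEDULER_THREADS + TMA_LOAD_THREADS + MMA_THREADS + EPILOGUE_THREADS


def warp_role(thread_idx: int) -> str:
    """Assumed role layout: scheduler, TMA, MMA, then the epilogue warps."""
    if not 0 <= thread_idx < NUM_THREADS:
        raise ValueError("thread index outside the CTA")
    warp = thread_idx // WARP_SIZE
    if warp == 0:
        return "scheduler"
    if warp == 1:
        return "tma"
    if warp == 2:
        return "mma"
    return "epilogue"


print(NUM_THREADS, NUM_THREADS // WARP_SIZE)
```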

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

a_expected_bytes

comptime a_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_smem_layout.size() * size_of[a_type]())

a_smem_layout

comptime a_smem_layout = tile_layout_k_major[a_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, config.a_swizzle]()

a_swizzle_elems

comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())

a_tile_dim1

comptime a_tile_dim1 = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N)

a_tma_load_size

comptime a_tma_load_size = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1 * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_swizzle_elems)
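The A-operand members above reduce to simple integer arithmetic. The Python sketch evaluates them for one example configuration: BM = BK = 128, CLUSTER_N = 2, a 128-byte swizzle, and a 1-byte element type such as FP8. These values are illustrative choices, not defaults of the kernel. One further assumption is that `tile_layout_k_major` yields a layout of exactly BM × BK elements, so that `a_expected_bytes` equals BM · BK · size_of[a_type].

```python
def a_operand_sizes(bm: int, bk: int, cluster_n: int,
                    swizzle_bytes: int, elem_bytes: int) -> dict:
    a_swizzle_elems = swizzle_bytes // elem_bytes  # elements per swizzled row
    a_tile_dim1 = bm // cluster_n                  # rows in the TMA descriptor box
    return {
        "a_swizzle_elems": a_swizzle_elems,
        "a_tile_dim1": a_tile_dim1,
        "a_tma_load_size": a_tile_dim1 * a_swizzle_elems,
        # Assumes the k-major smem layout holds exactly BM * BK elements.
        "a_expected_bytes": bm * bk * elem_bytes,
    }


sizes = a_operand_sizes(bm=128, bk=128, cluster_n=2, swizzle_bytes=128, elem_bytes=1)
print(sizes)
```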

a_tma_rows

comptime a_tma_rows = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1

accum_pipeline_consumer_arv_count

comptime accum_pipeline_consumer_arv_count = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS)

accum_pipeline_producer_arv_count

comptime accum_pipeline_producer_arv_count = 1

accum_type

comptime accum_type = DType.float32

ADescLayout

comptime ADescLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

ATileLayout

comptime ATileLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_tile_dim1].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

ATmaOp

comptime ATmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATmaTile.InnerType

ATmaTile

comptime ATmaTile = TMATile[a_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]

b_expected_bytes

comptime b_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_smem_layout.size() * size_of[b_type]())

b_smem_layout

comptime b_smem_layout = tile_layout_k_major[b_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, config.b_swizzle]()

b_swizzle_elems

comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())

b_tile_dim1

comptime b_tile_dim1 = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN // (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M // GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group))

b_tma_load_size

comptime b_tma_load_size = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1 * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_swizzle_elems)
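For B, the descriptor rows are BN divided by CLUSTER_M // cta_group rather than by CLUSTER_N. The Python sketch evaluates `b_tile_dim1` and `b_tma_load_size` for a few cluster shapes. BN = 128, a 128-byte swizzle and 1-byte elements are example values. The divisibility check is an assumption added to keep the sketch well defined.

```python
def b_operand_sizes(bn: int, cluster_m: int, cta_group: int,
                    swizzle_bytes: int, elem_bytes: int) -> tuple:
    if cluster_m % cta_group != 0:
        raise ValueError("sketch assumes CLUSTER_M is a multiple of cta_group")
    b_swizzle_elems = swizzle_bytes // elem_bytes
    b_tile_dim1 = bn // (cluster_m // cta_group)
    b_tma_load_size = b_tile_dim1 * b_swizzle_elems
    return b_tile_dim1, b_tma_load_size


for cluster_m, cta_group in [(1, 1), (2, 1), (2, 2), (4, 2)]:
    print(cluster_m, cta_group, b_operand_sizes(128, cluster_m, cta_group, 128, 1))
```

With cta_group = 2 and CLUSTER_M = 2, the divisor is 1, so each descriptor covers all BN rows.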

b_tma_rows

comptime b_tma_rows = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1

BDescLayout

comptime BDescLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

BK

comptime BK = config.block_tile_shape.__getitem__[Int](2)

BM

comptime BM = config.block_tile_shape.__getitem__[Int](0)

BN

comptime BN = config.block_tile_shape.__getitem__[Int](1)

BTileLayout

comptime BTileLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_tile_dim1].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

BTmaOp

comptime BTmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTmaTile.InnerType

BTmaTile

comptime BTmaTile = TMATile[b_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]

c_smem_layout

comptime c_smem_layout = Layout.row_major(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputN)

c_swizzle_elems

comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())

c_tile_dim1

comptime c_tile_dim1 = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputM if ((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M == 256) or (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group == 1) or config.AB_swapped) else 64

c_tile_dim2

comptime c_tile_dim2 = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_swizzle_elems if config.AB_swapped else GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].OutputN
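The selection logic for the C tile can be restated in Python. `c_tile_dim1` keeps the full OutputM rows when MMA_M == 256, when cta_group == 1, or when A and B are swapped, and falls back to 64 rows otherwise. `c_tile_dim2` uses the swizzle width when A and B are swapped and OutputN otherwise. The sample values are illustrative only: OutputM = 128, OutputN = 32, and a 2-byte output type (such as bfloat16) with a 128-byte swizzle.

```python
def c_tile_dims(output_m: int, output_n: int, mma_m: int, cta_group: int,
                ab_swapped: bool, c_swizzle_bytes: int, c_elem_bytes: int) -> tuple:
    c_swizzle_elems = c_swizzle_bytes // c_elem_bytes
    keep_full_rows = mma_m == 256 or cta_group == 1 or ab_swapped
    c_tile_dim1 = output_m if keep_full_rows else 64
    c_tile_dim2 = c_swizzle_elems if ab_swapped else output_n
    return c_tile_dim1, c_tile_dim2


print(c_tile_dims(128, 32, mma_m=256, cta_group=2, ab_swapped=False, c_swizzle_bytes=128, c_elem_bytes=2))
print(c_tile_dims(128, 32, mma_m=128, cta_group=2, ab_swapped=False, c_swizzle_bytes=128, c_elem_bytes=2))
print(c_tile_dims(128, 32, mma_m=128, cta_group=2, ab_swapped=True, c_swizzle_bytes=128, c_elem_bytes=2))
```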

CDescLayout

comptime CDescLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim1], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

clc_consumer_arv_count

comptime clc_consumer_arv_count = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group)

clc_producer_arv_count

comptime clc_producer_arv_count = 1

clc_throttle_consumer_arv_count

comptime clc_throttle_consumer_arv_count = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SCHEDULER_THREADS

clc_throttle_producer_arv_count

comptime clc_throttle_producer_arv_count = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS
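The pipeline arrival counts above depend only on the per-role thread counts and on cta_group. The Python sketch tabulates them for both cta_group values. TMA_LOAD_THREADS is not defined on this page, so setting it to one warp is an assumption.

```python
WARP_SIZE = 32
EPILOGUE_THREADS = 4 * WARP_SIZE
MMA_THREADS = WARP_SIZE
SCHEDULER_THREADS = WARP_SIZE
TMA_LOAD_THREADS = WARP_SIZE  # assumption: not defined on this page


def arrival_counts(cta_group: int) -> dict:
    return {
        "accum_producer": 1,
        "accum_consumer": cta_group * EPILOGUE_THREADS,
        "clc_producer": 1,
        "clc_consumer": MMA_THREADS * cta_group,
        "clc_throttle_producer": TMA_LOAD_THREADS,
        "clc_throttle_consumer": SCHEDULER_THREADS,
    }


for g in (1, 2):
    print(g, arrival_counts(g))
```

Only the consumer counts of the accumulator and CLC pipelines scale with cta_group. The producer and throttle counts stay fixed.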

CLUSTER_M

comptime CLUSTER_M = config.cluster_shape.__getitem__[Int](0)

CLUSTER_N

comptime CLUSTER_N = config.cluster_shape.__getitem__[Int](1)

CLUSTER_SIZE

comptime CLUSTER_SIZE = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N)

Context

comptime Context = KernelContext[0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_N]

cta_group

comptime cta_group = config.cta_group

CTileLayout

comptime CTileLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim2], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim1].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim2].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].c_tile_dim2].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

CTmaOp

comptime CTmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTmaTile.InnerType

CTmaTile

comptime CTmaTile = TMATile[c_type, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]

EPILOGUE_THREADS

comptime EPILOGUE_THREADS = (4 * WARP_SIZE)

EpilogueCtx

comptime EpilogueCtx = EpilogueWarpContext[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS]

GroupPtrLayout

comptime GroupPtrLayout = Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

GroupPtrTile

comptime GroupPtrTile = TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin]

input_expected_bytes

comptime input_expected_bytes = ((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group * (((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].a_expected_bytes + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].b_expected_bytes) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfa_expected_bytes) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfb_expected_bytes)) * config)

InputTilePipelineType

comptime InputTilePipelineType = InputTilePipeline[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size]

MMA_K

comptime MMA_K = config.mma_shape.__getitem__[Int](2)

MMA_M

comptime MMA_M = config.mma_shape.__getitem__[Int](0)

MMA_N

comptime MMA_N = config.mma_shape.__getitem__[Int](1)

MMA_THREADS

comptime MMA_THREADS = WARP_SIZE

MmaCtx

comptime MmaCtx = MmaWarpContext[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS]

MmaEpilogueSync

comptime MmaEpilogueSync = WarpGroupBarrier[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS), 1]

MmaOp

comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages_2sm

comptime num_clc_pipeline_stages_2sm = 2

num_group_pipeline_stages

comptime num_group_pipeline_stages = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages // config)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_output_warps

comptime num_output_warps = 4

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS

comptime NUM_THREADS = (((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SCHEDULER_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].EPILOGUE_THREADS)

NUM_TMEM_COLS

comptime NUM_TMEM_COLS = 512

OutputM

comptime OutputM = config.output_tile_shape.__getitem__[Int](0)

OutputN

comptime OutputN = config.output_tile_shape.__getitem__[Int](1)

OutputPipeline

comptime OutputPipeline = OutputTilePipeline[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]

ProblemSizesLayout

comptime ProblemSizesLayout = Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]]

ProblemSizesTile

comptime ProblemSizesTile = TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin]
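ProblemSizesLayout lists its shape (max_groups, 4) followed by its strides (4, 1). That describes a row-major int32 table with one row of four values per group. The Python sketch shows the resulting flat offsets. `MAX_GROUPS = 3` and the per-group values are made up for illustration. Reading the four columns as (M, N, K, L), as in CuTe-style grouped GEMM examples, is an assumption that this page does not confirm.

```python
MAX_GROUPS = 3  # example value for the max_groups parameter


def problem_sizes_offset(group: int, field: int) -> int:
    # Shape (max_groups, 4) with strides (4, 1): row-major, one row per group.
    if not (0 <= group < MAX_GROUPS and 0 <= field < 4):
        raise IndexError("index outside the (max_groups, 4) table")
    return group * 4 + field


# Hypothetical per-group entries; the (M, N, K, L) field order is an assumption.
sizes = [(256, 128, 512, 1), (100, 300, 512, 1), (64, 64, 256, 1)]
flat = [0] * (MAX_GROUPS * 4)
for g, row in enumerate(sizes):
    for j, v in enumerate(row):
        flat[problem_sizes_offset(g, j)] = v

print(flat[problem_sizes_offset(1, 1)], len(flat) * 4, "bytes")
```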

SCHEDULER_THREADS

comptime SCHEDULER_THREADS = WARP_SIZE

SchedulerType

comptime SchedulerType = GroupedTileScheduler[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BK, max_groups]

SF_K_GROUP_SIZE

comptime SF_K_GROUP_SIZE = (4 * config)

sfa_expected_bytes

comptime sfa_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfa_smem_layout.size() * size_of[sfa_dtype]())

SFA_NUM_COLS

comptime SFA_NUM_COLS = (config * (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // 32))

sfa_smem_layout

comptime sfa_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFADescLayout

comptime SFADescLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(SF_ATOM_M.__getitem__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

SFATileLayout

comptime SFATileLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(SF_ATOM_M.__getitem__[0]())], ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

SFATmaOp

comptime SFATmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATmaTile.InnerType

SFATmaTile

comptime SFATmaTile = TMATile[sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]

sfb_expected_bytes

comptime sfb_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].sfb_smem_layout.size() * size_of[sfb_dtype]())

SFB_NUM_COLS

comptime SFB_NUM_COLS = (config * (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // 32))

sfb_smem_layout

comptime sfb_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFBDescLayout

comptime SFBDescLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(SF_ATOM_M.__getitem__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

SFBTileLayout

comptime SFBTileLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(SF_ATOM_M.__getitem__[0]())], ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(SF_ATOM_M.__getitem__[0]())].static_value * ComptimeInt[(ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((SF_ATOM_M.__getitem__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

SFBTmaOp

comptime SFBTmaOp = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTmaTile.InnerType

SFBTmaTile

comptime SFBTmaTile = TMATile[sfb_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]

SmemType

comptime SmemType = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]

stage_stride_cols

comptime stage_stride_cols = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N
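`stage_stride_cols = MMA_N` suggests that consecutive accumulator pipeline stages sit `MMA_N` tensor-memory (TMEM) columns apart. Under that assumption, the hypothetical Python helper below computes each stage's starting column. It also checks the total against the 512 TMEM columns available per SM on SM100. The helper ignores the extra columns that `TmemRegion` reserves for scale factors, so the real budget is tighter.

```python
TMEM_COLUMNS = 512  # 32-bit TMEM columns per SM on SM100 (128 lanes each)


def accum_stage_columns(mma_n, num_accum_stages, base_col=0):
    """Hypothetical helper: starting TMEM column of each accumulator stage,
    assuming consecutive stages are stage_stride_cols (= MMA_N) columns apart."""
    end = base_col + num_accum_stages * mma_n
    if end > TMEM_COLUMNS:
        raise ValueError(
            f"accumulators need {end} columns, only {TMEM_COLUMNS} exist"
        )
    return [base_col + s * mma_n for s in range(num_accum_stages)]


print(accum_stage_columns(128, 2))  # [0, 128]
```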

TENSORMAP_AB_INIT_BARRIER_ID

comptime TENSORMAP_AB_INIT_BARRIER_ID = 3

TENSORMAP_AB_INIT_THREADS

comptime TENSORMAP_AB_INIT_THREADS = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TMA_LOAD_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_THREADS)

TensormapAbInitBarrier

comptime TensormapAbInitBarrier = WarpGroupBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TENSORMAP_AB_INIT_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TENSORMAP_AB_INIT_BARRIER_ID]
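`TensormapAbInitBarrier` is a named barrier (ID 3) shared by the TMA warp and the MMA warp. The MMA warp must not consume tiles until the TMA warp has initialized the A/B/SFA/SFB tensormaps. The Python model below mimics that handshake with `threading.Barrier`, treating each warp as a single thread. `MMA_THREADS` is not defined in this section, so setting it to one warp is an assumption. PTX requires a named barrier's thread count to be a multiple of the warp size, which the `assert` checks.

```python
import threading

WARP_SIZE = 32
TMA_LOAD_THREADS = WARP_SIZE  # one TMA warp, as in the struct
MMA_THREADS = WARP_SIZE       # assumption: one MMA warp
TENSORMAP_AB_INIT_THREADS = TMA_LOAD_THREADS + MMA_THREADS

# Named barriers (bar.sync id, count) need count to be a multiple of warp size.
assert TENSORMAP_AB_INIT_THREADS % WARP_SIZE == 0


def simulate_init_handshake():
    """Model each warp as one Python thread; the barrier releases only when
    both the TMA warp and the MMA warp have arrived."""
    barrier = threading.Barrier(2)
    tensormaps = {}
    seen_by_mma = []

    def tma_warp():
        tensormaps.update(A="desc_a", B="desc_b", SFA="desc_sfa", SFB="desc_sfb")
        barrier.wait()  # arrive: tensormaps are initialized

    def mma_warp():
        barrier.wait()  # block until the TMA warp has finished init
        seen_by_mma.append(sorted(tensormaps))

    threads = [threading.Thread(target=f) for f in (mma_warp, tma_warp)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen_by_mma[0]


print(simulate_init_handshake())  # ['A', 'B', 'SFA', 'SFB']
```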

TensormapManagerType

comptime TensormapManagerType = GroupedTensormapManager

TilePayload

comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.BK, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFA_DIM1, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.SFB_DIM1, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_pipeline_stages]

TileWriterType

comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, config.num_accum_pipeline_stages, config.c_swizzle, config.AB_swapped, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.OutputN, config.num_output_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, 4, elementwise_compute_lambda_fn, register_based_epilogue, True]

TMA_LOAD_THREADS

comptime TMA_LOAD_THREADS = WARP_SIZE

TMATensorTileArrayA

comptime TMATensorTileArrayA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()]

TMATensorTileArrayB

comptime TMATensorTileArrayB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()]

TMATensorTileArrayC

comptime TMATensorTileArrayC = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()]

TMATensorTileArraySFA

comptime TMATensorTileArraySFA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()]

TMATensorTileArraySFB

comptime TMATensorTileArraySFB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()]

Tmem

comptime Tmem = TmemAllocation[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]

TmemDealloc

comptime TmemDealloc = TmemDeallocBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group]

TmemRegion

comptime TmemRegion = BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles]

Methods

validate_config

static validate_config()

Compile-time validation of kernel configuration.

run

static run(a_tma_template: TMATensorTile[a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], b_tma_template: TMATensorTile[b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], c_tma_template: TMATensorTile[c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], sfa_tma_template: TMATensorTile[sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], sfb_tma_template: TMATensorTile[sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], device_tma_b: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], device_tma_c: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], group_a_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_b_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_c_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfa_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfb_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], problem_sizes_lt: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int)

Grouped block-scaled GEMM kernel entry point.

This kernel processes multiple GEMM problems (groups) with dynamic tensormap updates at group boundaries.
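The linear tile iteration across groups that `run` relies on can be modeled in a few lines of Python. The sketch makes three assumptions:

- Each row of the `max_groups x 4` tensor `problem_sizes_lt` holds `(M, N, K, L)` for one group, as in the CuTe DSL grouped GEMM example.
- Tiles are enumerated row-major within each group. This ordering is illustrative and is not necessarily the one `GroupedTileScheduler` uses.
- A tile is flagged whenever the group changes, which is where the kernel would update its tensormaps.

```python
def ceil_div(a, b):
    return -(-a // b)


def grouped_tiles(problem_sizes, bm, bn):
    """Yield (linear_idx, group, m_tile, n_tile, group_changed) for every tile.

    problem_sizes: one (M, N, K, L) row per group (assumed column order).
    """
    linear_idx = 0
    prev_group = None
    for group, (m, n, _k, _l) in enumerate(problem_sizes):
        tiles_m, tiles_n = ceil_div(m, bm), ceil_div(n, bn)
        for t in range(tiles_m * tiles_n):
            changed = group != prev_group  # tensormaps would be updated here
            yield linear_idx, group, t // tiles_n, t % tiles_n, changed
            prev_group = group
            linear_idx += 1


for tile in grouped_tiles([(256, 128, 64, 1), (128, 256, 64, 1)], 128, 128):
    print(tile)
```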

load_input_tiles

static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], b_tma_op: TMATensorTile[b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], sfa_tma_op: TMATensorTile[sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], sfb_tma_op: TMATensorTile[sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], tiles: InputProducerStage[tiles_origin, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)

Load the A, B, SFA, and SFB tiles using TMA into the given InputProducerStage.
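`a_multicast_mask` and `b_multicast_mask` are 16-bit CTA masks for TMA multicast within a cluster. The hypothetical Python helper below shows one common way to derive them. CTAs that share an M coordinate need the same A tile, and CTAs that share an N coordinate need the same B tile. The helper assumes CTA rank `m_idx + n_idx * cluster_m` and ignores the CTA pairing used by the 2SM path. Treat it as an illustration rather than the kernel's actual formula.

```python
def multicast_masks(cluster_m, cluster_n, m_idx, n_idx):
    """Hypothetical A/B multicast masks for the CTA at (m_idx, n_idx).

    Assumes CTA rank = m_idx + n_idx * cluster_m within the cluster.
    """
    assert cluster_m * cluster_n <= 16, "TMA multicast masks are 16 bits wide"
    a_mask = 0
    for n in range(cluster_n):  # same M row -> same A tile
        a_mask |= 1 << (m_idx + n * cluster_m)
    b_mask = 0
    for m in range(cluster_m):  # same N column -> same B tile
        b_mask |= 1 << (m + n_idx * cluster_m)
    return a_mask, b_mask


print(multicast_masks(2, 2, 0, 0))  # (5, 3): A to ranks 0 and 2, B to ranks 0 and 1
```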

mma

static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].TilePayload, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SmemType.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group, num_sf_k_tiles=config.num_sf_k_tiles], iter_idx: UInt32, k_start: UInt32)

Execute the block-scaled MMA operations on the tiles held by the given InputConsumerStage.
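For reference, the arithmetic a block-scaled MMA performs can be written out naively. In the Python reference model below (not the kernel's algorithm), every `vec_size` consecutive K elements of a row of A or B share one scale factor from SFA or SFB. B is stored N x K, which is the layout assumed here for `transpose_b=True`. Typical block sizes are 32 for MX formats and 16 for NVFP4. Here `vec_size` is simply a parameter.

```python
def block_scaled_matmul_ref(a, sfa, b, sfb, vec_size):
    """Naive block-scaled GEMM with B stored N x K.

    a: M x K values,  sfa: M x (K // vec_size) scale factors
    b: N x K values,  sfb: N x (K // vec_size) scale factors
    """
    m_dim, k_dim = len(a), len(a[0])
    n_dim = len(b)
    assert k_dim % vec_size == 0
    out = [[0.0] * n_dim for _ in range(m_dim)]
    for m in range(m_dim):
        for n in range(n_dim):
            acc = 0.0  # Python floats are doubles; the kernel accumulates in FP32 TMEM
            for k in range(k_dim):
                blk = k // vec_size  # scale-factor block that element k belongs to
                acc += (a[m][k] * sfa[m][blk]) * (b[n][k] * sfb[n][blk])
            out[m][n] = acc
    return out


print(block_scaled_matmul_ref([[1, 1, 1, 1]], [[2, 3]], [[1, 1, 1, 1]], [[1, 0.5]], 2))  # [[7.0]]
```

With all scale factors equal to 1, the function reduces to a plain `A @ B^T`.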

epilogue

static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], stage: OutputStage[config.num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].cta_group], work_tile_coord: Tuple[UInt32, UInt32, UInt32], M: UInt32, N: UInt32, alpha: Float32 = 1)

Execute the epilogue to store the accumulated results.
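`epilogue` takes the group's `M` and `N` plus a scalar `alpha`. Suppose, as is conventional for GEMM epilogues, that `alpha` scales the accumulator and `M`/`N` bound the valid region of C. Under that assumption, storing one output tile reduces to the Python sketch below. TMA stores do not write elements outside the tensor bounds encoded in the tensormap, so the explicit bounds check here mainly keeps the sketch self-contained.

```python
def store_tile(c, acc_tile, tile_m, tile_n, bm, bn, m_bound, n_bound, alpha=1.0):
    """Write alpha * acc into c for one BM x BN tile, skipping out-of-bounds elements."""
    row0, col0 = tile_m * bm, tile_n * bn
    for i in range(bm):
        for j in range(bn):
            r, col = row0 + i, col0 + j
            if r < m_bound and col < n_bound:  # clip partial edge tiles
                c[r][col] = alpha * acc_tile[i][j]
    return c


c = [[0.0] * 3 for _ in range(3)]
store_tile(c, [[1, 2], [3, 4]], 1, 1, 2, 2, 3, 3, alpha=2.0)
print(c)  # only c[2][2] is in bounds, so it becomes 2.0
```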

run_2sm

static run_2sm(a_tma_template: TMATensorTile[a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], b_tma_template: TMATensorTile[b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], c_tma_template: TMATensorTile[c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], sfa_tma_template: TMATensorTile[sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], sfb_tma_template: TMATensorTile[sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, a_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].ADescLayout]()], device_tma_b: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, b_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].BDescLayout]()], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfa_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFATileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFADescLayout]()], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, sfb_dtype, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].SFBDescLayout]()], device_tma_c: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CLUSTER_SIZE, c_type, _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CTileLayout](), _to_legacy_layout[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue].CDescLayout]()], group_a_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_b_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_c_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfa_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfb_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], problem_sizes_lt: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int)

Grouped block-scaled GEMM kernel with 2SM (cta_group=2) support.

This entry point uses CLC-based (cluster launch control) work distribution to keep the two CTAs of a 2SM pair synchronized within their cluster. Both CTAs cooperate on each tile: both issue TMA loads, while only one CTA issues the MMA instructions.

The warp-role architecture matches that of block_scaled_matmul_kernel:

  • Scheduler warp: Produces work items via CLC barriers
  • TMA warp: Loads tiles with tensormap updates on group change
  • MMA warp: Waits on CLC, executes MMA (elected CTA only)
  • Epilogue warps: Stores results with tensormap updates
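The scheduler-to-consumer flow in these bullets can be pictured as a broadcast queue. The Python model below is only a conceptual stand-in for CLC. One scheduler thread hands every work item to each role in the same order and finishes with a sentinel. This mirrors how the TMA, MMA, and epilogue warps must all see an identical tile sequence. The model does not represent the 2SM CTA pair or the hardware barriers.

```python
import queue
import threading

SENTINEL = None  # stands in for a "no more work" response


def run_pipeline(work_items, roles=("tma", "mma", "epilogue")):
    """One scheduler broadcasts work items; each role consumes them in order."""
    queues = {role: queue.Queue() for role in roles}
    seen = {role: [] for role in roles}

    def scheduler():
        for item in list(work_items) + [SENTINEL]:
            for q in queues.values():
                q.put(item)  # every role receives every work item

    def consumer(role):
        while (item := queues[role].get()) is not SENTINEL:
            seen[role].append(item)

    threads = [threading.Thread(target=scheduler)]
    threads += [threading.Thread(target=consumer, args=(r,)) for r in roles]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen


print(run_pipeline([(0, 0, 0), (0, 1, 0), (1, 0, 0)]))
```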
