Mojo struct
GroupedBlockScaledMatmulKernel
struct GroupedBlockScaledMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], max_groups: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(SIMD(1)), elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None]
Grouped block-scaled matmul kernel with dynamic tensormap updates.
This kernel extends BlackwellBlockScaledMatmulKernel to support grouped GEMM:
- Uses GroupedTileScheduler for linear tile iteration across groups
- Uses GroupedTensormapManager for per-block updatable TMA descriptors
- Updates tensormaps when transitioning between groups
Architecture (aligned with NVIDIA CuTe DSL grouped_blockscaled_gemm.py):
- TMA warp: Initializes A/B/SFA/SFB tensormaps, handles group transitions
- MMA warp: Waits for tensormap init, consumes tiles, performs block-scaled MMA
- Epilogue warps: Initializes C tensormap, handles C group transitions
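The "linear tile iteration across groups" idea can be sketched in plain Python. This is an illustrative model only, not the `GroupedTileScheduler` API: the function name `iter_group_tiles`, the `(M, N)` group shapes, and the `group_changed` flag are all assumptions made for the example. The flag marks the points at which a kernel of this kind would need to refresh its tensormaps.

```python
def iter_group_tiles(group_shapes, bm, bn):
    """Yield (linear_idx, group_idx, tile_m, tile_n, group_changed).

    group_shapes: list of (M, N) problem sizes, one per group (hypothetical input).
    bm, bn: block tile sizes along M and N.
    """
    linear_idx = 0
    prev_group = None
    for group_idx, (m, n) in enumerate(group_shapes):
        tiles_m = -(-m // bm)  # ceiling division
        tiles_n = -(-n // bn)
        for tile_m in range(tiles_m):
            for tile_n in range(tiles_n):
                # A group transition is where descriptors would be updated.
                group_changed = group_idx != prev_group
                yield linear_idx, group_idx, tile_m, tile_n, group_changed
                prev_group = group_idx
                linear_idx += 1


tiles = list(iter_group_tiles([(256, 128), (128, 256)], bm=128, bn=128))
for tile in tiles:
    print(tile)
```

In this model the first tile of every group, including the very first tile, reports `group_changed=True`. That corresponds loosely to the initial tensormap setup followed by one update per group transition.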
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
a_expected_bytes​
comptime a_expected_bytes = ((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK) * size_of[a_type]())
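As a worked instance of the formula above, the following Python sketch evaluates `BM * BK * size_of[a_type]()` for one configuration. The tile sizes and the 1-byte element width are assumptions chosen for illustration, not values taken from any particular `config`.

```python
def tile_bytes(rows, cols, elem_bytes):
    """Bytes occupied by a rows x cols tile of elements of the given width."""
    return rows * cols * elem_bytes


# Assumed example values: a 128 x 128 A tile of 1-byte elements.
BM, BK = 128, 128
A_ELEM_BYTES = 1

a_expected_bytes = tile_bytes(BM, BK, A_ELEM_BYTES)
print(a_expected_bytes)
```

`b_expected_bytes` follows the same pattern with `BN` in place of `BM` and `size_of[b_type]()` as the element width.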
a_smem_layout​
comptime a_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // 8)]()), Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK // (config.a_swizzle.bytes() // size_of[a_type]()))]())), Coord(Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(8 * (config.a_swizzle.bytes() // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK == (config.a_swizzle.bytes() // size_of[a_type]())) else (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM * (config.a_swizzle.bytes() // size_of[a_type]()))]()))).to_layout()
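The nested expression above is easier to follow with concrete numbers. The Python sketch below assumes the usual (shape, stride) reading of `Layout(Coord(...), Coord(...))`, with the first `Coord` as the shape and the second as the stride, and computes the linear element offset of coordinate `(m, k)` before any hardware swizzle is applied. The helper name `a_smem_offset` and the example sizes are hypothetical.

```python
def a_smem_offset(m, k, bm, bk, swizzle_elems):
    """Linear offset of (m, k) under the shape/stride pair shown above.

    shape  = ((8, bm // 8), (swizzle_elems, bk // swizzle_elems))
    stride = ((swizzle_elems, 8 * swizzle_elems), (1, k_outer_stride))
    """
    assert 0 <= m < bm and 0 <= k < bk
    # The outer K stride collapses to 0 when one swizzle row spans all of BK.
    k_outer_stride = 0 if bk == swizzle_elems else bm * swizzle_elems
    m_in, m_out = m % 8, m // 8
    k_in, k_out = k % swizzle_elems, k // swizzle_elems
    return (
        m_in * swizzle_elems
        + m_out * 8 * swizzle_elems
        + k_in * 1
        + k_out * k_outer_stride
    )


# Assumed example: BM=128, BK=64, 32 elements per swizzle row.
BM, BK, SWIZZLE_ELEMS = 128, 64, 32
offsets = {a_smem_offset(m, k, BM, BK, SWIZZLE_ELEMS)
           for m in range(BM) for k in range(BK)}
print(len(offsets), min(offsets), max(offsets))
```

Under these assumptions every `(m, k)` pair maps to a distinct offset, rows are `swizzle_elems` apart, and each additional swizzle-wide K chunk starts `BM * swizzle_elems` elements later.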
a_swizzle_elems​
comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())
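For a quick sense of scale, this Python sketch evaluates the swizzle-width-to-element-count conversion for a few combinations. The swizzle widths and element sizes are assumed example values, and `swizzle_elems` is a hypothetical helper name.

```python
def swizzle_elems(swizzle_bytes, elem_bytes):
    """Number of elements that fit in one swizzle row (integer division)."""
    return swizzle_bytes // elem_bytes


# Assumed example values: swizzle widths in bytes and element sizes in bytes.
examples = {
    (128, 1): swizzle_elems(128, 1),
    (128, 2): swizzle_elems(128, 2),
    (64, 1): swizzle_elems(64, 1),
}
for (width, size), count in examples.items():
    print(f"{width}B swizzle, {size}B elements -> {count} elements per row")
```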
a_tile_dim0​
comptime a_tile_dim0 = compute_tma_tile_dims[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, config.AB_swapped]()[0]
a_tma_load_size​
comptime a_tma_load_size = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0 * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_swizzle_elems)
a_tma_rows​
comptime a_tma_rows = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0
accum_pipeline_consumer_arv_count​
comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].EPILOGUE_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group]()[1]
accum_pipeline_producer_arv_count​
comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].EPILOGUE_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group]()[0]
accum_type​
comptime accum_type = DType.float32
ADescLayout​
comptime ADescLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
ATileLayout​
comptime ATileLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
ATmaOp​
comptime ATmaOp = TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
b_expected_bytes​
comptime b_expected_bytes = ((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK) * size_of[b_type]())
b_smem_layout​
comptime b_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK == (config.b_swizzle.bytes() // size_of[b_type]())) else (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).to_layout() if transpose_b else Layout(Coord(Coord(Idx[8](), Idx[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN == (config.b_swizzle.bytes() // size_of[b_type]())) else (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).transpose().to_layout()
b_swizzle_elems​
comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())
b_tile_dim0​
comptime b_tile_dim0 = compute_tma_tile_dims[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, config.AB_swapped]()[1]
b_tma_load_size​
comptime b_tma_load_size = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0 * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_swizzle_elems)
b_tma_rows​
comptime b_tma_rows = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0
BDescLayout​
comptime BDescLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
BK
comptime BK = config.block_tile_shape[2]
BM
comptime BM = config.block_tile_shape[0]
BN
comptime BN = config.block_tile_shape[1]
BTileLayout​
comptime BTileLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
BTmaOp​
comptime BTmaOp = TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
c_smem_layout​
comptime c_smem_layout = Layout.row_major(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].OutputN)
c_swizzle_elems​
comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())
c_tile_dim0​
comptime c_tile_dim0 = compute_tma_tile_dims[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, config.AB_swapped]()[2]
c_tile_dim1​
comptime c_tile_dim1 = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_swizzle_elems if config.AB_swapped else GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].OutputN
CDescLayout​
comptime CDescLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
clc_consumer_arv_count​
comptime clc_consumer_arv_count = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group)
clc_producer_arv_count​
comptime clc_producer_arv_count = 1
clc_throttle_consumer_arv_count​
comptime clc_throttle_consumer_arv_count = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SCHEDULER_THREADS
clc_throttle_producer_arv_count​
comptime clc_throttle_producer_arv_count = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].TMA_LOAD_THREADS
CLUSTER_M
comptime CLUSTER_M = config.cluster_shape[0]
CLUSTER_N
comptime CLUSTER_N = config.cluster_shape[1]
CLUSTER_SIZE​
comptime CLUSTER_SIZE = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M * GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N)
Context​
comptime Context = KernelContext[0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N]
cta_group​
comptime cta_group = config.cta_group
CTileLayout​
comptime CTileLayout = Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
CTmaOp​
comptime CTmaOp = TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * 
ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
EPILOGUE_THREADS​
comptime EPILOGUE_THREADS = (4 * WARP_SIZE)
EpilogueCtx​
comptime EpilogueCtx = EpilogueWarpContext[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].EPILOGUE_THREADS]
GroupPtrLayout​
comptime GroupPtrLayout = Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
GroupPtrTile​
comptime GroupPtrTile = TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin]
input_expected_bytes​
comptime input_expected_bytes = ((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group * (((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_expected_bytes + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_expected_bytes) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].sfa_expected_bytes) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].sfb_expected_bytes)) * config)
InputTilePipelineType​
comptime InputTilePipelineType = InputTilePipeline[BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size]
MMA_K​
comptime MMA_K = config.mma_shape[2]
MMA_M​
comptime MMA_M = config.mma_shape[0]
MMA_N​
comptime MMA_N = config.mma_shape[1]
MMA_THREADS​
comptime MMA_THREADS = WARP_SIZE
MmaCtx​
comptime MmaCtx = MmaWarpContext[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].EPILOGUE_THREADS]
MmaEpilogueSync​
comptime MmaEpilogueSync = WarpGroupBarrier[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].EPILOGUE_THREADS), 1]
MmaOp​
comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
num_accum_pipeline_stages​
comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages
num_clc_pipeline_stages_2sm​
comptime num_clc_pipeline_stages_2sm = 2
num_group_pipeline_stages​
comptime num_group_pipeline_stages = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages // config.k_group_size)
num_output_stages​
comptime num_output_stages = config.num_output_stages
num_output_warps​
comptime num_output_warps = 4
num_pipeline_stages​
comptime num_pipeline_stages = config.num_pipeline_stages
NUM_THREADS​
comptime NUM_THREADS = (((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SCHEDULER_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].TMA_LOAD_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].EPILOGUE_THREADS)
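The thread-count members on this page are sums of per-role warp counts. The following Python sketch works through that arithmetic only; it is not kernel code. It assumes `WARP_SIZE` is 32, and it assumes `EPILOGUE_THREADS` (defined outside this section) equals `num_output_warps * WARP_SIZE`, which this page does not confirm.

```python
# Plain-arithmetic sketch of the thread budget; not Mojo and not kernel code.
WARP_SIZE = 32  # assumption: NVIDIA warp width

SCHEDULER_THREADS = WARP_SIZE   # as listed: SCHEDULER_THREADS = WARP_SIZE
TMA_LOAD_THREADS = WARP_SIZE    # as listed: TMA_LOAD_THREADS = WARP_SIZE
MMA_THREADS = WARP_SIZE         # as listed: MMA_THREADS = WARP_SIZE
NUM_OUTPUT_WARPS = 4            # as listed: num_output_warps = 4

# Assumption: EPILOGUE_THREADS is not shown in this section; one warp per
# output warp is a guess used only to make the sums concrete.
EPILOGUE_THREADS = NUM_OUTPUT_WARPS * WARP_SIZE

# NUM_THREADS sums all four warp roles.
NUM_THREADS = SCHEDULER_THREADS + TMA_LOAD_THREADS + MMA_THREADS + EPILOGUE_THREADS

# TENSORMAP_AB_INIT_THREADS covers only the TMA and MMA warps.
TENSORMAP_AB_INIT_THREADS = TMA_LOAD_THREADS + MMA_THREADS

# MmaEpilogueSync is sized for the MMA and epilogue warps together.
MMA_EPILOGUE_SYNC_THREADS = MMA_THREADS + EPILOGUE_THREADS

print(NUM_THREADS, TENSORMAP_AB_INIT_THREADS, MMA_EPILOGUE_SYNC_THREADS)
```

Under these assumptions the block has 224 threads, of which 64 take part in the A/B tensormap init barrier and 160 in the MMA/epilogue barrier.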
NUM_TMEM_COLS​
comptime NUM_TMEM_COLS = 512
opc​
comptime opc = OutputPipelineConfig(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group)
OutputM​
comptime OutputM = config.output_tile_shape[0]
OutputN​
comptime OutputN = config.output_tile_shape[1]
OutputPipeline​
comptime OutputPipeline = OutputTilePipeline[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc]
ProblemSizesLayout​
comptime ProblemSizesLayout = Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]]
ProblemSizesTile​
comptime ProblemSizesTile = TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin]
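`ProblemSizesLayout` reads as a shape of `(max_groups, 4)` with strides `(4, 1)`, that is, a row-major table holding four `int32` values per group. The Python sketch below shows the resulting offset computation. The helper name `problem_size_offset` and the value of `max_groups` are hypothetical, and the meaning of the four fields is not stated in this section.

```python
# Sketch of row-major addressing for a (max_groups, 4) table with strides (4, 1).
def problem_size_offset(group, field, row_stride=4, col_stride=1):
    """Linear offset of entry (group, field) under strides (4, 1)."""
    return group * row_stride + field * col_stride

max_groups = 3  # hypothetical value for illustration

# A flat buffer standing in for the int32 storage behind the tile.
flat = list(range(max_groups * 4))

# The four entries belonging to group 1 are contiguous in memory.
group1 = [flat[problem_size_offset(1, f)] for f in range(4)]
print(group1)
```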
register_based_epilogue​
comptime register_based_epilogue = config.register_based_epilogue
SCHEDULER_THREADS​
comptime SCHEDULER_THREADS = WARP_SIZE
SchedulerType​
comptime SchedulerType = GroupedTileScheduler[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK, max_groups]
SF_K_GROUP_SIZE​
comptime SF_K_GROUP_SIZE = (4 * config)
sfa_expected_bytes​
comptime sfa_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].sfa_smem_layout.size() * size_of[sfa_dtype]())
SFA_NUM_COLS​
comptime SFA_NUM_COLS = (config * (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // 32))
sfa_smem_layout​
comptime sfa_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SF_K_GROUP_SIZE * config), config.vec_sf_size]()
SFADescLayout​
comptime SFADescLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
SFATileLayout​
comptime SFATileLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
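The nested `static_value` products in `SFATileLayout` are row-major strides: each stride is the product of all extents to its right. The Python sketch below reproduces that pattern. The concrete numbers (`BM`, `SF_MN_GROUP_SIZE`, `num_sf_k_tiles`, `SF_ATOM_M`) are hypothetical placeholders, since this section does not give their values.

```python
# Sketch: row-major strides as suffix products of the shape.
def row_major_strides(shape):
    """Return strides where strides[i] = shape[i+1] * strides[i+1]."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

# Hypothetical values, chosen only to make the example concrete.
BM = 128
SF_MN_GROUP_SIZE = 128
num_sf_k_tiles = 1
SF_ATOM_M = (32, 4)

# Same five extents, in the same order, as the SFATileLayout shape.
shape = (
    1,
    BM // SF_MN_GROUP_SIZE,
    num_sf_k_tiles,
    SF_ATOM_M[0],
    SF_ATOM_M[1] * 4,
)
strides = row_major_strides(shape)
print(shape, strides)
```

With these placeholder values the shape is `(1, 1, 1, 32, 16)` and the strides are `(512, 512, 512, 16, 1)`; the same construction applies to `SFBTileLayout` with `MMA_N` in place of `BM`.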
SFATmaOp​
comptime SFATmaOp = TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
sfb_expected_bytes​
comptime sfb_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].sfb_smem_layout.size() * size_of[sfb_dtype]())
SFB_NUM_COLS​
comptime SFB_NUM_COLS = (config * (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // 32))
sfb_smem_layout​
comptime sfb_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SF_K_GROUP_SIZE * config), config.vec_sf_size]()
SFBDescLayout​
comptime SFBDescLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
SFBTileLayout​
comptime SFBTileLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
SFBTmaOp​
comptime SFBTmaOp = TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
SmemType​
comptime SmemType = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]
stage_stride_cols​
comptime stage_stride_cols = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N
TENSORMAP_AB_INIT_BARRIER_ID​
comptime TENSORMAP_AB_INIT_BARRIER_ID = 3
TENSORMAP_AB_INIT_THREADS​
comptime TENSORMAP_AB_INIT_THREADS = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].TMA_LOAD_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS)
TensormapAbInitBarrier​
comptime TensormapAbInitBarrier = WarpGroupBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].TENSORMAP_AB_INIT_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].TENSORMAP_AB_INIT_BARRIER_ID]
TensormapManagerType​
comptime TensormapManagerType = GroupedTensormapManager
TilePayload​
comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages]
TileWriterType​
comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc, config.c_swizzle, config.AB_swapped, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.OutputN, config.num_output_stages, 4, elementwise_compute_lambda_fn=elementwise_compute_lambda_fn, register_based_epilogue=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].register_based_epilogue, batched=True]
TMA_LOAD_THREADS​
comptime TMA_LOAD_THREADS = WARP_SIZE
TMATensorTileArrayA​
comptime TMATensorTileArrayA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
TMATensorTileArrayB​
comptime TMATensorTileArrayB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
TMATensorTileArrayC
comptime TMATensorTileArrayC = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
TMATensorTileArraySFA
comptime TMATensorTileArraySFA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
TMATensorTileArraySFB
comptime TMATensorTileArraySFB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, 
sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
Tmem
comptime Tmem = TmemAllocation[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc.cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc.cta_group]
TmemRegion
comptime TmemRegion = BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, num_sf_k_tiles=config.num_sf_k_tiles]
Methods
validate_config
static validate_config()
Validates the kernel configuration at compile time.
init_barriers
static init_barriers(ctx: KernelContext[0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N], a_tma_template: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, 
max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_template: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], c_tma_template: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, 
max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_template: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_template: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 
4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], input_barriers: SMemArray[SharedMemBarrier, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_accum_pipeline_stages * 2)], tmem_dealloc: SMemArray[SharedMemBarrier, 1])
Initialize barriers and prefetch TMA descriptors (1SM path, without cluster launch control (CLC)).
init_barriers_2sm
static init_barriers_2sm(ctx: KernelContext[0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N], a_tma_template: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_template: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], c_tma_template: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, 
sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_template: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_template: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 
4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], input_barriers: SMemArray[SharedMemBarrier, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (config * 2)], clc_full: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, 1])
Initialize barriers and prefetch TMA descriptors (2SM path, with CLC).
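The barrier array sizes in the signature above follow a simple pattern, which the following sketch works through. It is a minimal illustration and not the kernel's code: the helper `paired_barrier_count` and the stage counts are hypothetical, and reading the `* 2` factor as one "full" plus one "empty" barrier per pipeline stage is an assumption based on the separate `clc_full` and `clc_empty` arrays. `clc_throttle` is left out of the arithmetic.

```mojo
def paired_barrier_count(num_stages: Int) -> Int:
    # Assumption: each pipeline stage owns one "full" and one "empty"
    # barrier, which would explain the `* 2` in the array sizes.
    return num_stages * 2


def main():
    # Hypothetical stage counts, chosen only for illustration.
    var num_group_pipeline_stages = 4
    var num_accum_pipeline_stages = 2
    var num_clc_pipeline_stages = 2

    # input_barriers and accum_barriers are sized as (stages * 2).
    var input_barriers = paired_barrier_count(num_group_pipeline_stages)
    var accum_barriers = paired_barrier_count(num_accum_pipeline_stages)

    # clc_full and clc_empty are separate arrays, one barrier per stage each.
    var clc_full = num_clc_pipeline_stages
    var clc_empty = num_clc_pipeline_stages

    # tmem_dealloc is a single barrier.
    var tmem_dealloc = 1

    var total = input_barriers + accum_barriers + clc_full + clc_empty + tmem_dealloc
    print("input_barriers:", input_barriers)
    print("accum_barriers:", accum_barriers)
    print("clc_full + clc_empty:", clc_full + clc_empty)
    print("total barriers in this sketch:", total)
```

In the real kernel these counts are compile-time values derived from `config` and `GroupedBlockScaledSmem`, so the arithmetic would be resolved at compile time rather than at run time as in this sketch.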
run
static run(a_tma_template: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_template: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, 
sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], c_tma_template: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_template: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_template: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * 
ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], 
ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // 
size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_b: 
TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].CLUSTER_SIZE, sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem 
SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem 
SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_c: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], group_a_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_b_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_c_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfa_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfb_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], problem_sizes_lt: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int)
Grouped block-scaled GEMM kernel entry point.
This kernel processes multiple GEMM problems (groups) with dynamic tensormap updates at group boundaries.
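The iteration pattern behind "dynamic tensormap updates at group boundaries" can be illustrated with a small host-side model. This is a minimal sketch written in Python, not Mojo kernel code. It assumes each group is described by an (M, N, K) triple, that tiles are numbered linearly across groups, and that persistent CTAs take every `num_ctas`-th tile. The names `tiles_for_cta`, `cta_id`, and `num_ctas` are hypothetical and do not come from the kernel API. The real `problem_sizes_lt` tensor stores four `int32` values per group, and the actual `GroupedTileScheduler` may order tiles differently.

```python
from typing import Iterator, Sequence, Tuple


def ceil_div(a: int, b: int) -> int:
    """Integer division rounded up."""
    return (a + b - 1) // b


def tiles_for_cta(
    problem_sizes: Sequence[Tuple[int, int, int]],
    bm: int,
    bn: int,
    cta_id: int = 0,
    num_ctas: int = 1,
) -> Iterator[Tuple[int, int, int, bool]]:
    """Yield (group, tile_m, tile_n, group_changed) for one CTA.

    group_changed is True when this CTA's previous tile belonged to a
    different group (or when it is the CTA's first tile). In the real
    kernel this is the point where tensormaps would be updated.
    """
    prev_group = None
    linear_base = 0  # linear index of the first tile of the current group
    for group, (m, n, _k) in enumerate(problem_sizes):
        tiles_n = ceil_div(n, bn)
        count = ceil_div(m, bm) * tiles_n
        for local in range(count):
            # Hypothetical assignment: CTA i takes tiles i, i + num_ctas, ...
            if (linear_base + local) % num_ctas == cta_id:
                tile_m, tile_n = divmod(local, tiles_n)
                yield group, tile_m, tile_n, group != prev_group
                prev_group = group
        linear_base += count


if __name__ == "__main__":
    sizes = [(256, 128, 64), (128, 256, 64)]
    for item in tiles_for_cta(sizes, bm=128, bn=128):
        print(item)
```

In this model a group with zero tiles is skipped, so a CTA only sees a group change when it actually receives a tile from a new group.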
load_input_tiles​
static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * 
ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_op: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], 
ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), 
_to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_op: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_op: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * 
ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], 
ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], tiles: ProducerTiles[tiles_origin, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)
Load the A, B, SFA, and SFB tiles using TMA with InputProducerStage.
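As a rough illustration of the byte accounting behind a tile load, the comptime member `a_expected_bytes` listed earlier on this page is defined as `BM * BK * size_of[a_type]()`. The plain-Python sketch below reproduces only that arithmetic. The helper name `expected_bytes`, the tile extents, and the 1-byte element size are assumptions chosen for illustration; they are not taken from the kernel configuration, and the sketch does not claim how the kernel combines the counts for A, B, SFA, and SFB.

```python
def expected_bytes(rows: int, cols: int, elem_size_bytes: int) -> int:
    """Bytes occupied by a dense rows x cols tile of fixed-size elements."""
    return rows * cols * elem_size_bytes


# Hypothetical values, for illustration only (not read from `config`).
BM, BK = 128, 64
A_ELEM_SIZE = 1  # assumed 1-byte element type

a_bytes = expected_bytes(BM, BK, A_ELEM_SIZE)
print(a_bytes)
```

The same formula would apply to any other dense tile once its extents and element size are known.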
mma
static mma[tiles_origin: MutOrigin, //](tiles: ConsumerTiles[tiles_origin, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, 
max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, num_sf_k_tiles=config.num_sf_k_tiles], iter_idx: UInt32, k_start: UInt32)
Execute MMA operations using ConsumerTiles.
epilogue
static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], 
ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], stage: OutputStage[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc], work_tile_coord: Tuple[UInt32, UInt32, UInt32], M: UInt32, N: UInt32, alpha: Float32 = 1)
Execute the epilogue to store the accumulated results.
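The long `Layout[...]` parameters in the signature above can be hard to read. For `c_tma_op` they appear to list three extents, `(1, c_tile_dim0, c_tile_dim1)`, followed by three strides, `(c_tile_dim0 * c_tile_dim1, c_tile_dim1, 1)`, which is the compact row-major pattern where each stride is the product of all later extents. The plain-Python sketch below reproduces that pattern. The function name `compact_row_major_strides` and the numeric tile extents are hypothetical, and this reading of the parameter order is an assumption based on the signature text alone.

```python
from math import prod


def compact_row_major_strides(shape):
    """Return one stride per dimension: the product of all later extents."""
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))


# Hypothetical tile extents, chosen only for illustration.
c_tile_dim0, c_tile_dim1 = 64, 32

shape = (1, c_tile_dim0, c_tile_dim1)
strides = compact_row_major_strides(shape)
print(shape, strides)
```

The same suffix-product rule seems to account for the five-dimensional scale-factor layouts in the other signatures on this page, where the stride entries are written out as nested `static_value` products.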
run_2sm
static run_2sm(a_tma_template: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_template: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, 
sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], c_tma_template: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_template: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_template: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * 
ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], 
ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // 
size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_b: 
TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].CLUSTER_SIZE, sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem 
SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem 
SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_c: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], group_a_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_b_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_c_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfa_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfb_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], problem_sizes_lt: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int)
Grouped block-scaled GEMM kernel with 2SM (cta_group=2) support.
This entry point uses CLC-based work distribution so that the CTAs in a cluster stay synchronized in 2SM mode. Both CTAs cooperate on each tile: both issue TMA loads, while only the elected CTA issues the MMA.
The warp-role architecture matches block_scaled_matmul_kernel:
- Scheduler warp: Produces work items via CLC barriers
- TMA warp: Loads tiles with tensormap updates on group change
- MMA warp: Waits on CLC, executes MMA (elected CTA only)
- Epilogue warps: Stores results with tensormap updates
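The "tensormap updates on group change" behavior in the list above can be illustrated with a small host-side model. This is a simplified Python sketch, not the kernel's Mojo implementation: the function names `ceil_div` and `iter_tiles`, the `(M, N, K)` tuple layout, and the M-major tile order are assumptions made for illustration, and CLC-based work distribution is not modeled. It only shows how a linear walk over tiles across groups identifies the points where a per-group descriptor would need to be refreshed.

```python
from typing import Iterator, List, Tuple


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return (a + b - 1) // b


def iter_tiles(
    problem_sizes: List[Tuple[int, int, int]], bm: int, bn: int
) -> Iterator[Tuple[int, int, int, bool]]:
    """Walk output tiles linearly across all groups.

    Yields (group, m_tile, n_tile, group_changed). The group_changed flag is
    True for the first tile of each group, which is where this model assumes
    a tensormap (TMA descriptor) update would be needed.
    """
    prev_group = -1
    for group, (m, n, _k) in enumerate(problem_sizes):
        # Groups with zero rows or columns produce no tiles and are skipped.
        for m_tile in range(ceil_div(m, bm)):
            for n_tile in range(ceil_div(n, bn)):
                yield group, m_tile, n_tile, group != prev_group
                prev_group = group


# Hypothetical per-group (M, N, K) sizes and a 128x128 output tile.
sizes = [(256, 128, 64), (128, 256, 64)]
tiles = list(iter_tiles(sizes, bm=128, bn=128))
updates = sum(1 for t in tiles if t[3])
```

In this model the number of descriptor updates equals the number of non-empty groups, regardless of how many tiles each group contributes.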