
Mojo struct

GroupedBlockScaledMatmulKernel

struct GroupedBlockScaledMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], max_groups: Int, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(SIMD(1)), elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None]

Grouped block-scaled matmul kernel with dynamic tensormap updates.

This kernel extends BlackwellBlockScaledMatmulKernel to support grouped GEMM:

  • Uses GroupedTileScheduler for linear tile iteration across groups
  • Uses GroupedTensormapManager for per-block updatable TMA descriptors
  • Updates tensormaps when transitioning between groups
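The grouped scheduling idea above can be modeled in a few lines of Python. This is a sketch under stated assumptions, not the GroupedTileScheduler implementation. It assumes three things: each group's output tiles are numbered row-major in BM x BN tiles, groups are laid out back to back in one linear index space, and a persistent CTA strides through that space by the number of CTAs. The helper names (`tile_prefix`, `locate_tile`, `schedule`) and the `"update_tensormap"` event are hypothetical. They only mark where the kernel would refresh its TMA descriptors.

```python
from bisect import bisect_right
from itertools import accumulate


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def tile_prefix(problem_sizes, bm, bn):
    """Prefix sums of per-group tile counts; prefix[g] is group g's first linear tile."""
    counts = [ceil_div(m, bm) * ceil_div(n, bn) for m, n in problem_sizes]
    return [0] + list(accumulate(counts))


def locate_tile(linear_idx, prefix, problem_sizes, bn):
    """Map a linear tile index to (group, m_tile, n_tile); empty groups are skipped."""
    group = bisect_right(prefix, linear_idx) - 1
    local = linear_idx - prefix[group]
    tiles_n = ceil_div(problem_sizes[group][1], bn)
    m_tile, n_tile = divmod(local, tiles_n)
    return group, m_tile, n_tile


def schedule(cta_id, num_ctas, problem_sizes, bm, bn):
    """Tiles visited by one persistent CTA, with a marker at each group transition."""
    prefix = tile_prefix(problem_sizes, bm, bn)
    events, current_group = [], None
    for idx in range(cta_id, prefix[-1], num_ctas):
        group, m_tile, n_tile = locate_tile(idx, prefix, problem_sizes, bn)
        if group != current_group:
            # Hypothetical hook: the kernel would update its tensormaps here.
            events.append(("update_tensormap", group))
            current_group = group
        events.append(("tile", group, m_tile, n_tile))
    return events
```

For two groups of sizes 256 x 128 and 128 x 256 with 128 x 128 tiles, a single CTA emits a transition marker before the first tile of group 0 and again before the first tile of group 1. Groups with no tiles never appear, because their prefix entries coincide with the next group's.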

Architecture (aligned with NVIDIA CuTe DSL grouped_blockscaled_gemm.py):

  • TMA warp: Initializes the A/B/SFA/SFB tensormaps and handles group transitions
  • MMA warp: Waits for tensormap initialization, consumes input tiles, and performs the block-scaled MMA
  • Epilogue warps: Initialize the C tensormap and handle C group transitions

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_expected_bytes

comptime a_expected_bytes = ((Self.BM * Self.BK) * size_of[a_type]())
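As a worked example of `a_expected_bytes` (and the analogous `b_expected_bytes`), the byte count is tile area times element size. The tile shape and element sizes below are illustrative assumptions, not kernel defaults. The 1-byte case stands for an 8-bit element type and the 2-byte case for a 16-bit one.

```python
def expected_bytes(rows: int, bk: int, elem_bytes: int) -> int:
    """Byte size of a rows x BK operand tile (rows = BM for A, BN for B)."""
    return rows * bk * elem_bytes


# Illustrative block_tile_shape = (BM, BN, BK); not taken from any real config.
BM, BN, BK = 128, 256, 128
a_bytes = expected_bytes(BM, BK, 1)   # A tile with 8-bit elements
b_bytes = expected_bytes(BN, BK, 1)   # B tile with 8-bit elements
```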

a_smem_layout

comptime a_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(Self.BM // 8)]()), Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(Self.BK // (config.a_swizzle.bytes() // size_of[a_type]()))]())), Coord(Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(8 * (config.a_swizzle.bytes() // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (Self.BK == (config.a_swizzle.bytes() // size_of[a_type]())) else (Self.BM * (config.a_swizzle.bytes() // size_of[a_type]()))]()))).to_layout()
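The nested `Coord` layout above reads as shape `((8, BM // 8), (sw, BK // sw))` with strides `((sw, 8 * sw), (1, 0 if BK == sw else BM * sw))`, where `sw` is `a_swizzle_elems`. The sketch below evaluates that layout at a coordinate `(m, k)`. It assumes CuTe-style colexicographic indexing, where the first sub-mode varies fastest. It returns the element offset before the hardware swizzle pattern is applied. The swizzle XOR itself is deliberately not modeled, and `a_smem_offset` is a hypothetical helper.

```python
def a_smem_offset(m: int, k: int, bm: int, bk: int, swizzle_bytes: int, elem_bytes: int) -> int:
    """Pre-swizzle element offset of A[m, k] in the K-major shared-memory layout."""
    sw = swizzle_bytes // elem_bytes          # a_swizzle_elems
    assert bm % 8 == 0 and bk % sw == 0
    m0, m1 = m % 8, m // 8                    # mode (8, BM // 8), strides (sw, 8 * sw)
    k0, k1 = k % sw, k // sw                  # mode (sw, BK // sw), strides (1, col_stride)
    col_stride = 0 if bk == sw else bm * sw   # each swizzle-wide column holds BM * sw elements
    return m0 * sw + m1 * (8 * sw) + k0 + k1 * col_stride
```

When `BK == sw` (for example 8-bit elements, a 128-byte swizzle, and BK = 128), the layout reduces to `m * sw + k`. The zero stride is harmless because `BK // sw == 1`, so `k1` is always 0.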

a_swizzle_elems

comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())

a_tile_dim0

comptime a_tile_dim0 = compute_tma_tile_dims[Self.BM, Self.BN, Self.MMA_M, Self.OutputM, Self.CLUSTER_M, Self.CLUSTER_N, Self.cta_group, config.AB_swapped]()[0]

a_tma_load_size

comptime a_tma_load_size = (Self.a_tile_dim0 * Self.a_swizzle_elems)
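`a_swizzle_elems` and `a_tma_load_size` follow directly from the swizzle width. The sketch below also compares two shapes: the TMA box implied by `ADescLayout` (`a_tile_dim0` x `a_swizzle_elems`) and the tile implied by `ATileLayout` (`a_tile_dim0` x `BK`). Treating their ratio as the number of boxes along K is an inference from those two layouts, not something this reference states. `a_tile_dim0` comes from `compute_tma_tile_dims`, whose logic is not shown here, so it is taken as an input. `tma_box` is a hypothetical helper.

```python
def tma_box(tile_dim0: int, bk: int, swizzle_bytes: int, elem_bytes: int):
    """Return (swizzle_elems, load_elems, load_bytes, boxes_along_k) for one operand."""
    sw = swizzle_bytes // elem_bytes       # a_swizzle_elems
    load_elems = tile_dim0 * sw            # a_tma_load_size, in elements
    load_bytes = load_elems * elem_bytes   # always tile_dim0 * swizzle_bytes
    boxes_along_k = bk // sw               # ATileLayout K extent / ADescLayout K extent
    return sw, load_elems, load_bytes, boxes_along_k
```

With a 128-byte swizzle, each box covers exactly one swizzle width per row. Halving the element size doubles `a_swizzle_elems` and halves the number of boxes needed along K.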

a_tma_rows

comptime a_tma_rows = Self.a_tile_dim0

accum_pipeline_consumer_arv_count

comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[Self.EPILOGUE_THREADS, Self.cta_group]()[1]

accum_pipeline_producer_arv_count

comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[Self.EPILOGUE_THREADS, Self.cta_group]()[0]

accum_type

comptime accum_type = DType.float32

ADescLayout

comptime ADescLayout = Layout[ComptimeInt[1], ComptimeInt[Self.a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

ATileLayout

comptime ATileLayout = Layout[ComptimeInt[1], ComptimeInt[Self.a_tile_dim0], ComptimeInt[Self.BK], ComptimeInt[(ComptimeInt[Self.a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[Self.BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[Self.BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
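This reading assumes the `Layout[...]` parameter order is three shape values followed by three strides. Under that assumption, `ATileLayout` is the row-major layout of shape `(1, a_tile_dim0, BK)`, and the nested `ComptimeInt` products are running products of the trailing extents. The same rule gives `GroupPtrLayout`'s shape `(max_groups, 1)` with strides `(1, 1)`, and the `Layout.row_major(OutputM, OutputN)` used for `c_smem_layout`. A minimal sketch, with `row_major_strides` as a hypothetical helper:

```python
def row_major_strides(shape):
    """Stride of each mode is the product of all extents to its right."""
    strides, running = [], 1
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    return tuple(reversed(strides))
```

For `a_tile_dim0 = BK = 128`, this yields strides `(16384, 128, 1)`, matching the `(a_tile_dim0 * (BK * 1), BK * 1, 1)` pattern in the definition.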

ATmaOp

comptime ATmaOp = TMATensorTile[a_type, Self.ATileLayout.rank, _to_index_list[Self.ATileLayout](), _to_index_list[Self.ATileLayout.rank, Self.ADescLayout]()]

b_expected_bytes

comptime b_expected_bytes = ((Self.BN * Self.BK) * size_of[b_type]())

b_smem_layout

comptime b_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(Self.BN // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(Self.BK // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (Self.BK == (config.b_swizzle.bytes() // size_of[b_type]())) else (Self.BN * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).to_layout() if transpose_b else Layout(Coord(Coord(Idx[8](), Idx[(Self.BK // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(Self.BN // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (Self.BN == (config.b_swizzle.bytes() // size_of[b_type]())) else (Self.BK * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).transpose().to_layout()

b_swizzle_elems

comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())

b_tile_dim0

comptime b_tile_dim0 = compute_tma_tile_dims[Self.BM, Self.BN, Self.MMA_M, Self.OutputM, Self.CLUSTER_M, Self.CLUSTER_N, Self.cta_group, config.AB_swapped]()[1]

b_tma_load_size

comptime b_tma_load_size = (Self.b_tile_dim0 * Self.b_swizzle_elems)

b_tma_rows

comptime b_tma_rows = Self.b_tile_dim0

BDescLayout

comptime BDescLayout = Layout[ComptimeInt[1], ComptimeInt[Self.b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

BK

comptime BK = config.block_tile_shape[2]

BM

comptime BM = config.block_tile_shape[0]

BN

comptime BN = config.block_tile_shape[1]

BTileLayout

comptime BTileLayout = Layout[ComptimeInt[1], ComptimeInt[Self.b_tile_dim0], ComptimeInt[Self.BK], ComptimeInt[(ComptimeInt[Self.b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[Self.BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[Self.BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

BTmaOp

comptime BTmaOp = TMATensorTile[b_type, Self.BTileLayout.rank, _to_index_list[Self.BTileLayout](), _to_index_list[Self.BTileLayout.rank, Self.BDescLayout]()]

c_smem_layout

comptime c_smem_layout = Layout.row_major(Self.OutputM, Self.OutputN)

c_swizzle_elems

comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())

c_tile_dim0

comptime c_tile_dim0 = compute_tma_tile_dims[Self.BM, Self.BN, Self.MMA_M, Self.OutputM, Self.CLUSTER_M, Self.CLUSTER_N, Self.cta_group, config.AB_swapped]()[2]

c_tile_dim1

comptime c_tile_dim1 = Self.c_swizzle_elems if config.AB_swapped else Self.OutputN

CDescLayout

comptime CDescLayout = Layout[ComptimeInt[1], ComptimeInt[Self.c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

clc_consumer_arv_count

comptime clc_consumer_arv_count = (Self.MMA_THREADS * Self.cta_group)

clc_producer_arv_count

comptime clc_producer_arv_count = 1

clc_throttle_consumer_arv_count

comptime clc_throttle_consumer_arv_count = Self.SCHEDULER_THREADS

clc_throttle_producer_arv_count

comptime clc_throttle_producer_arv_count = Self.TMA_LOAD_THREADS

CLUSTER_M

comptime CLUSTER_M = config.cluster_shape[0]

CLUSTER_N

comptime CLUSTER_N = config.cluster_shape[1]

CLUSTER_SIZE

comptime CLUSTER_SIZE = (Self.CLUSTER_M * Self.CLUSTER_N)

Context

comptime Context = KernelContext[0, Self.cta_group, Self.CLUSTER_M, Self.CLUSTER_N]

cta_group

comptime cta_group = config.cta_group

CTileLayout

comptime CTileLayout = Layout[ComptimeInt[1], ComptimeInt[Self.c_tile_dim0], ComptimeInt[Self.c_tile_dim1], ComptimeInt[(ComptimeInt[Self.c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[Self.c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[Self.c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

CTmaOp

comptime CTmaOp = TMATensorTile[c_type, Self.CTileLayout.rank, _to_index_list[Self.CTileLayout](), _to_index_list[Self.CTileLayout.rank, Self.CDescLayout]()]

EPILOGUE_THREADS

comptime EPILOGUE_THREADS = (4 * WARP_SIZE)

EpilogueCtx

comptime EpilogueCtx = EpilogueWarpContext[Self.opc, Self.MMA_THREADS, Self.EPILOGUE_THREADS]

GroupPtrLayout

comptime GroupPtrLayout = Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

GroupPtrTile

comptime GroupPtrTile = TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin]
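`GroupPtrTile` is a tile of `max_groups` `uint64` values. The host-side sketch below fills such a table under two assumptions that this reference does not specify: each entry holds one group's 64-bit device address, and slots beyond the active group count are zero-filled. `build_group_ptr_table` is a hypothetical helper, and the addresses used with it are placeholders.

```python
from array import array


def build_group_ptr_table(addresses, max_groups):
    """Pack per-group 64-bit addresses into a fixed-size, zero-padded table."""
    if len(addresses) > max_groups:
        raise ValueError(f"{len(addresses)} groups exceed max_groups={max_groups}")
    table = array("Q", addresses)                       # "Q": unsigned long long (64-bit)
    table.extend([0] * (max_groups - len(addresses)))   # unused slots stay 0
    return table
```

A fixed `max_groups` length keeps the layout a compile-time constant, as `GroupPtrLayout` requires. The number of active groups can then vary at run time up to that bound.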

input_expected_bytes

comptime input_expected_bytes = ((Self.cta_group * (((Self.a_expected_bytes + Self.b_expected_bytes) + Self.sfa_expected_bytes) + Self.sfb_expected_bytes)) * config)

InputTilePipelineType

comptime InputTilePipelineType = InputTilePipeline[BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(Self.SmemType.Core.BM, Self.SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Self.SmemType.Core.BN, Self.SmemType.Core.BK, __list_literal__=Tuple()), IndexList(Self.SmemType.Core.SFA_DIM0, Self.SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(Self.SmemType.Core.SFB_DIM0, Self.SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), Self.SmemType.Core.num_pipeline_stages], Self.SmemType.Core.num_group_pipeline_stages, config.k_group_size]

MMA_K

comptime MMA_K = config.mma_shape[2]

MMA_M

comptime MMA_M = config.mma_shape[0]

MMA_N

comptime MMA_N = config.mma_shape[1]

MMA_THREADS

comptime MMA_THREADS = WARP_SIZE

MmaCtx

comptime MmaCtx = MmaWarpContext[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].EPILOGUE_THREADS]

MmaEpilogueSync

comptime MmaEpilogueSync = WarpGroupBarrier[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].EPILOGUE_THREADS), 1]

MmaOp

comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages_2sm

comptime num_clc_pipeline_stages_2sm = 2

num_group_pipeline_stages

comptime num_group_pipeline_stages = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages // config)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_output_warps

comptime num_output_warps = 4

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS

comptime NUM_THREADS = (((GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SCHEDULER_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].TMA_LOAD_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS) + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].EPILOGUE_THREADS)
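
A worked example of the `NUM_THREADS` sum, written as a small Python model of the comptime arithmetic rather than Mojo library code. `SCHEDULER_THREADS`, `TMA_LOAD_THREADS`, and `MMA_THREADS` are each `WARP_SIZE` per this section; the value `WARP_SIZE = 32` and the choice `EPILOGUE_THREADS = num_output_warps * WARP_SIZE` are assumptions, because `EPILOGUE_THREADS` is defined outside this section.

```python
# Python model of the comptime thread-count arithmetic (not library code).
WARP_SIZE = 32                 # warp width on NVIDIA GPUs
SCHEDULER_THREADS = WARP_SIZE  # one scheduler warp
TMA_LOAD_THREADS = WARP_SIZE   # one TMA producer warp
MMA_THREADS = WARP_SIZE        # one MMA warp
num_output_warps = 4

# Assumption: EPILOGUE_THREADS is defined outside this section; it is
# modeled here as one full warp per output warp.
EPILOGUE_THREADS = num_output_warps * WARP_SIZE

NUM_THREADS = (
    SCHEDULER_THREADS + TMA_LOAD_THREADS + MMA_THREADS + EPILOGUE_THREADS
)
print(NUM_THREADS, NUM_THREADS // WARP_SIZE)  # 224 7
```

Under those assumptions each CTA runs 224 threads: seven warps, one each for the scheduler, TMA, and MMA roles plus four epilogue warps.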

NUM_TMEM_COLS

comptime NUM_TMEM_COLS = 512

opc

comptime opc = OutputPipelineConfig(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].stage_stride_cols, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group)

OutputM

comptime OutputM = config.output_tile_shape[0]

OutputN

comptime OutputN = config.output_tile_shape[1]

OutputPipeline

comptime OutputPipeline = OutputTilePipeline[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc]

ProblemSizesLayout

comptime ProblemSizesLayout = Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]]

ProblemSizesTile

comptime ProblemSizesTile = TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin]
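
`ProblemSizesLayout` lists shape `(max_groups, 4)` followed by strides `(4, 1)`, so it is a row-major table with one row of four `int32` values per group, and entry `j` of group `g` lives at flat offset `4*g + j`. The Python sketch below models that indexing. The column order `(M, N, K, L)` is an assumption modeled on the CuTe DSL grouped GEMM example this kernel says it follows; this section does not name the columns.

```python
# Hypothetical model of ProblemSizesLayout: shape (max_groups, 4),
# strides (4, 1), stored as one flat buffer of int32 values.
MAX_GROUPS = 3
SHAPE = (MAX_GROUPS, 4)
STRIDES = (4, 1)

def offset(g: int, j: int) -> int:
    """Flat offset of entry j of group g under strides (4, 1)."""
    return g * STRIDES[0] + j * STRIDES[1]

# Assumed column order (M, N, K, L), modeled on the CuTe DSL example.
problems = [(256, 512, 1024, 1), (128, 256, 512, 1), (64, 128, 256, 1)]
sizes = [0] * (SHAPE[0] * SHAPE[1])
for g, row in enumerate(problems):
    for j, value in enumerate(row):
        sizes[offset(g, j)] = value

print(sizes[offset(1, 2)])  # K of group 1: 512
```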

register_based_epilogue

comptime register_based_epilogue = config.register_based_epilogue

SCHEDULER_THREADS

comptime SCHEDULER_THREADS = WARP_SIZE

SchedulerType

comptime SchedulerType = GroupedTileScheduler[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK, max_groups]
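
`GroupedTileScheduler` is parameterized only by `BM`, `BN`, `BK`, and `max_groups`, and iterates tiles linearly across groups. One way to picture that is to concatenate every group's `ceil(M/BM) * ceil(N/BN)` output tiles and decode a linear index back into `(group, m_tile, n_tile)`. The Python sketch below does exactly that; the helper `decode_tile` is hypothetical, and the group-by-group, M-fastest ordering is an assumption rather than the scheduler's documented rasterization.

```python
# Illustrative model of linear tile iteration across groups (not library code).
# Assumption: tiles are ordered group by group, M-tile index fastest.
from typing import List, Optional, Tuple

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def decode_tile(linear: int, problems: List[Tuple[int, int]],
                BM: int, BN: int) -> Optional[Tuple[int, int, int]]:
    """Hypothetical helper: linear tile index -> (group, m_tile, n_tile)."""
    for g, (M, N) in enumerate(problems):
        tiles_m = ceil_div(M, BM)
        tiles_n = ceil_div(N, BN)
        count = tiles_m * tiles_n
        if linear < count:
            return g, linear % tiles_m, linear // tiles_m
        linear -= count  # skip past this group's tiles
    return None  # index is beyond the last group

# Two groups given as (M, N); BM = BN = 128
problems = [(256, 256), (128, 384)]  # group 0: 2x2 tiles, group 1: 1x3 tiles
for t in range(8):
    print(t, decode_tile(t, problems, 128, 128))
```

In a scheme like this, a tile whose group differs from the previous tile's marks a group transition, which is the point where the kernel updates its tensormaps.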

SF_K_GROUP_SIZE

comptime SF_K_GROUP_SIZE = (4 * config)

sfa_expected_bytes

comptime sfa_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].sfa_smem_layout.size() * size_of[sfa_dtype]())

SFA_NUM_COLS

comptime SFA_NUM_COLS = (config * (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // 32))

sfa_smem_layout

comptime sfa_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFADescLayout

comptime SFADescLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

SFATileLayout

comptime SFATileLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]
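
The nested `ComptimeInt[...].static_value` products in `SFATileLayout` (and likewise `SFBTileLayout`) spell out compact row-major strides for its five-dimensional shape: each stride is the product of all extents to its right, and the innermost stride is 1. The Python sketch below computes strides by that rule. The concrete extents (`SF_ATOM_M = (32, 4)`, `SF_MN_GROUP_SIZE = 128`, `BM = 128`, two scale-factor K tiles) are illustrative assumptions, not values read from the kernel.

```python
from math import prod
from typing import List, Sequence

def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Stride of each dim = product of the extents to its right (1 for the last)."""
    return [prod(shape[i + 1:]) for i in range(len(shape))]

# Illustrative extents (assumptions, not read from the kernel):
SF_ATOM_M = (32, 4)
SF_MN_GROUP_SIZE = 128
BM, num_sf_k_tiles = 128, 2

shape = (1, BM // SF_MN_GROUP_SIZE, num_sf_k_tiles, SF_ATOM_M[0], SF_ATOM_M[1] * 4)
print(shape)                     # (1, 1, 2, 32, 16)
print(row_major_strides(shape))  # [1024, 1024, 512, 16, 1]
```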

SFATmaOp

comptime SFATmaOp = TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]

sfb_expected_bytes

comptime sfb_expected_bytes = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].sfb_smem_layout.size() * size_of[sfb_dtype]())

SFB_NUM_COLS

comptime SFB_NUM_COLS = (config * (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // 32))

sfb_smem_layout

comptime sfb_smem_layout = tile_sf_layout_k_major[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N, (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFBDescLayout

comptime SFBDescLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

SFBTileLayout

comptime SFBTileLayout = Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

SFBTmaOp

comptime SFBTmaOp = TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]

SmemType

comptime SmemType = GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]

stage_stride_cols

comptime stage_stride_cols = GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N
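
`stage_stride_cols` equals `MMA_N` and is passed to `OutputPipelineConfig` alongside `num_accum_pipeline_stages`, which suggests each accumulator stage occupies `MMA_N` tensor-memory columns out of `NUM_TMEM_COLS = 512`. Assuming the stages are laid out back to back (the allocation code is not shown here), they fit only if `num_accum_pipeline_stages * MMA_N` stays within 512. Whether the scale-factor columns (`SFA_NUM_COLS`, `SFB_NUM_COLS`) share that budget is not stated in this section, so the Python sketch exposes them as an optional, hypothetical `sf_cols` argument.

```python
NUM_TMEM_COLS = 512  # as declared by the kernel

def accum_stages_fit(num_accum_pipeline_stages: int, mma_n: int,
                     sf_cols: int = 0) -> bool:
    """Assumed budget: each accumulator stage takes stage_stride_cols (== MMA_N)
    columns; sf_cols is a hypothetical reservation for scale factors."""
    stage_stride_cols = mma_n
    used = num_accum_pipeline_stages * stage_stride_cols + sf_cols
    return used <= NUM_TMEM_COLS

print(accum_stages_fit(2, 256))              # True: 2 * 256 == 512
print(accum_stages_fit(3, 256))              # False: 768 > 512
print(accum_stages_fit(2, 128, sf_cols=32))  # True: 256 + 32 <= 512
```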

TENSORMAP_AB_INIT_BARRIER_ID

comptime TENSORMAP_AB_INIT_BARRIER_ID = 3

TENSORMAP_AB_INIT_THREADS

comptime TENSORMAP_AB_INIT_THREADS = (GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].TMA_LOAD_THREADS + GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_THREADS)

TensormapAbInitBarrier

comptime TensormapAbInitBarrier = WarpGroupBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].TENSORMAP_AB_INIT_THREADS, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].TENSORMAP_AB_INIT_BARRIER_ID]
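
Two named hardware barriers are declared in this section: `MmaEpilogueSync` (ID 1, the MMA warp plus the epilogue warps) and `TensormapAbInitBarrier` (ID 3, the TMA warp plus the MMA warp, so the MMA warp can wait for tensormap initialization). PTX named barriers (`bar.sync id, count`) take IDs 0 through 15 and require the thread count to be a multiple of the warp size. The Python sketch below checks both declarations against those rules; `WARP_SIZE = 32` and `EPILOGUE_THREADS = 4 * WARP_SIZE` are assumptions, since `EPILOGUE_THREADS` is not defined here.

```python
# Python check of the two named barriers declared by the kernel.
WARP_SIZE = 32
MMA_THREADS = WARP_SIZE
TMA_LOAD_THREADS = WARP_SIZE
EPILOGUE_THREADS = 4 * WARP_SIZE  # assumption: defined outside this section

# name -> (barrier ID, participating threads)
barriers = {
    "MmaEpilogueSync": (1, MMA_THREADS + EPILOGUE_THREADS),
    "TensormapAbInitBarrier": (3, TMA_LOAD_THREADS + MMA_THREADS),
}

def valid_named_barrier(bar_id: int, count: int) -> bool:
    """PTX bar.sync rules: ID in 0..15, count a positive multiple of the warp size."""
    return 0 <= bar_id <= 15 and count > 0 and count % WARP_SIZE == 0

for name, (bar_id, count) in barriers.items():
    print(name, bar_id, count, valid_named_barrier(bar_id, count))

ids = [bar_id for bar_id, _ in barriers.values()]
assert len(ids) == len(set(ids)), "named barrier IDs must be distinct"
```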

TensormapManagerType

comptime TensormapManagerType = GroupedTensormapManager

TilePayload

comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages]

TileWriterType

comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc, config.c_swizzle, config.AB_swapped, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.OutputM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.OutputN, config.num_output_stages, 4, elementwise_compute_lambda_fn=elementwise_compute_lambda_fn, register_based_epilogue=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].register_based_epilogue, batched=True]

TMA_LOAD_THREADS

comptime TMA_LOAD_THREADS = WARP_SIZE

TMATensorTileArrayA

comptime TMATensorTileArrayA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]

TMATensorTileArrayB​

comptime TMATensorTileArrayB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, 
sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
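This reading of the `Layout[...]` arguments is our interpretation of the rendered signature; the page does not state it. Treat each argument list as three extents followed by three strides. The global-tile layout for B then has shape `(1, b_tile_dim0, BK)` and strides `(b_tile_dim0 * BK, BK, 1)`. Each stride is the product of the extents to its right, which describes a contiguous row-major tile. The nested `ComptimeInt[... * ComptimeInt[1].static_value]` terms are that suffix product written out. The Python sketch below reproduces the arithmetic. The helper name and the example sizes (`b_tile_dim0 = 128`, `BK = 128`) are hypothetical.

```python
def row_major_strides(shape):
    """Return strides where each entry is the product of all extents to its right."""
    strides = [1] * len(shape)
    running = 1
    # Walk from the innermost (last) dimension outward.
    for i in range(len(shape) - 1, -1, -1):
        strides[i] = running
        running *= shape[i]
    return tuple(strides)


# Hypothetical tile sizes, chosen only for illustration.
b_tile_dim0, BK = 128, 128
b_shape = (1, b_tile_dim0, BK)
b_strides = row_major_strides(b_shape)
print(b_shape, b_strides)  # (1, 128, 128) (16384, 128, 1)
```

Every extent and stride is a `ComptimeInt`, so the whole layout is fixed at compile time and the product chain costs nothing at run time.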

TMATensorTileArrayC

comptime TMATensorTileArrayC = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
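Look at the last `Layout[...]` argument of `TMATensorTileArrayC` (B follows the same pattern). It appears to describe the TMA box, and its innermost extent is `config.c_swizzle.bytes() // size_of[c_type]()`. That value is the number of elements that fit in one swizzle span. TMA offers 32-, 64- and 128-byte swizzle modes, so the box width in elements depends on both the swizzle mode and the element size. The minimal Python sketch below shows that division. The dtype byte widths are standard; the helper name and the examples are hypothetical.

```python
# Byte sizes of a few element types (standard widths).
ELEM_BYTES = {"float8_e4m3fn": 1, "bfloat16": 2, "float16": 2, "float32": 4}


def swizzle_box_width(swizzle_bytes, dtype):
    """Elements per row of a box whose inner dimension spans one swizzle width."""
    elem = ELEM_BYTES[dtype]
    if swizzle_bytes % elem != 0:
        raise ValueError("swizzle width must be a multiple of the element size")
    return swizzle_bytes // elem


for dtype in ("bfloat16", "float32"):
    print(dtype, swizzle_box_width(128, dtype))
# bfloat16 64
# float32 32
```

Under a 128-byte swizzle, a wider `c_type` yields a narrower box. The inner extent of the C box therefore shrinks as the output precision grows.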

TMATensorTileArraySFA

comptime TMATensorTileArraySFA = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]

TMATensorTileArraySFB

comptime TMATensorTileArraySFB = TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, 
sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()]
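`TMATensorTileArraySFA` and `TMATensorTileArraySFB` share one five-dimensional scale-factor layout. They differ only in the M/N extent: SFA uses `BM // SF_MN_GROUP_SIZE` and SFB uses `MMA_N // SF_MN_GROUP_SIZE`. We read the arguments as extents followed by strides; this is our interpretation, not stated on the page. Under that reading the shape is `(1, mn // SF_MN_GROUP_SIZE, num_sf_k_tiles, SF_ATOM_M[0], SF_ATOM_M[1] * 4)` with row-major strides. The page gives no values for `SF_MN_GROUP_SIZE` or `SF_ATOM_M`. The sketch below therefore ASSUMES:

- `SF_MN_GROUP_SIZE = 128` and `SF_ATOM_M = (32, 4)`
- `BM = 128`, `MMA_N = 256`, `num_sf_k_tiles = 4`
- one-byte scale factors

All of these are hypothetical, as are the helper names.

```python
from math import prod


def sf_tile_shape(mn_extent, num_sf_k_tiles, sf_mn_group_size, sf_atom_m):
    """Shape of one scale-factor tile: (1, mn / group, k tiles, atom rows, atom cols * 4)."""
    if mn_extent % sf_mn_group_size != 0:
        raise ValueError("M/N extent must be a multiple of SF_MN_GROUP_SIZE")
    return (1, mn_extent // sf_mn_group_size, num_sf_k_tiles,
            sf_atom_m[0], sf_atom_m[1] * 4)


def row_major_strides(shape):
    """Each stride is the product of the extents to its right."""
    strides, running = [], 1
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    return tuple(reversed(strides))


# ASSUMED values for illustration; not taken from the kernel source.
SF_MN_GROUP_SIZE, SF_ATOM_M = 128, (32, 4)
BM, MMA_N, NUM_SF_K_TILES, SF_BYTES = 128, 256, 4, 1

sfa = sf_tile_shape(BM, NUM_SF_K_TILES, SF_MN_GROUP_SIZE, SF_ATOM_M)
sfb = sf_tile_shape(MMA_N, NUM_SF_K_TILES, SF_MN_GROUP_SIZE, SF_ATOM_M)
print(sfa, row_major_strides(sfa), prod(sfa) * SF_BYTES)
# (1, 1, 4, 32, 16) (2048, 2048, 512, 16, 1) 2048
print(sfb, row_major_strides(sfb), prod(sfb) * SF_BYTES)
# (1, 2, 4, 32, 16) (4096, 2048, 512, 16, 1) 4096
```

Only the second extent changes between the two aliases. With these assumed numbers, a wider `MMA_N` scales the SFB tile footprint linearly while the inner atom layout stays the same.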

Tmem

comptime Tmem = TmemAllocation[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc.cta_group]

TmemDealloc

comptime TmemDealloc = TmemDeallocBarrier[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc.cta_group]

TmemRegion

comptime TmemRegion = BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, num_sf_k_tiles=config.num_sf_k_tiles]

Methods

validate_config

static validate_config()

Compile-time validation of kernel configuration.

init_barriers

static init_barriers(ctx: KernelContext[0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N], a_tma_template: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, 
max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_template: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], c_tma_template: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, 
max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_template: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_template: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 
4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], input_barriers: SMemArray[SharedMemBarrier, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_accum_pipeline_stages * 2)], tmem_dealloc: SMemArray[SharedMemBarrier, 1])

Initialize barriers and prefetch TMA descriptors for the 1SM path, which does not use cluster launch control (CLC).

init_barriers_2sm

static init_barriers_2sm(ctx: KernelContext[0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_N], a_tma_template: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_template: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], c_tma_template: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, 
sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_template: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_template: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 
4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], input_barriers: SMemArray[SharedMemBarrier, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (GroupedBlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (config * 2)], clc_full: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, 1])

Initialize barriers and prefetch TMA descriptors (2SM path, with cluster launch control (CLC)).
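The signature above sizes each barrier array from the pipeline stage counts in `config`. The Python sketch below is only a rough model of that arithmetic and does not reproduce the kernel. Several parts of it are hypothetical: the `PipelineStages` class, the `barrier_counts` helper, and the `num_clc_throttle_stages` field. The rendered signature shows the `clc_throttle` size only as `(config * 2)`, so that field is a guess. Reading each `* 2` as one "full" plus one "empty" barrier per stage is also an assumption, because the signature does not state it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PipelineStages:
    """Hypothetical stand-in for the stage counts read from `config`."""

    num_group_pipeline_stages: int
    num_accum_pipeline_stages: int
    num_clc_pipeline_stages: int
    # Hypothetical field: the rendered signature only shows `(config * 2)`.
    num_clc_throttle_stages: int


def barrier_counts(s: PipelineStages) -> dict[str, int]:
    """Return the number of SharedMemBarrier slots in each SMemArray argument."""
    return {
        # Assumption: `* 2` means one "full" and one "empty" barrier per stage.
        "input_barriers": s.num_group_pipeline_stages * 2,
        "accum_barriers": s.num_accum_pipeline_stages * 2,
        "clc_throttle": s.num_clc_throttle_stages * 2,
        # CLC full and empty barriers are separate arrays, one slot per CLC stage.
        "clc_full": s.num_clc_pipeline_stages,
        "clc_empty": s.num_clc_pipeline_stages,
        # `tmem_dealloc` is a single-slot array in the signature.
        "tmem_dealloc": 1,
    }


if __name__ == "__main__":
    # Example stage counts chosen for illustration, not the kernel's defaults.
    stages = PipelineStages(4, 2, 2, 1)
    counts = barrier_counts(stages)
    print(counts, "total:", sum(counts.values()))
```

Because every array length is a compile-time function of the stage counts, changing the pipeline depth in `config` changes the shared-memory barrier footprint without any runtime bookkeeping.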

run

static run(a_tma_template: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_template: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, 
sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], c_tma_template: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_template: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_template: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * 
ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], 
ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // 
size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_b: 
TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].CLUSTER_SIZE, sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem 
SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem 
SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_c: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], group_a_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_b_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_c_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfa_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfb_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], problem_sizes_lt: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int)

Grouped block-scaled GEMM kernel entry point.

This kernel processes multiple independent GEMM problems (groups) in a single launch, updating its TMA tensormaps dynamically whenever a block moves from one group to the next.
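The grouped flow can be modeled on the host in plain Python. The sketch below is a simplified, hypothetical model, not the kernel's actual GroupedTileScheduler. It makes several assumptions:

- Each row of problem_sizes_lt holds (M, N, K, L), as in the CuTe DSL grouped GEMM example.
- Only M and N matter for tiling, and L is treated as 1.
- Tiles within a group are ordered row-major.
- Each persistent CTA strides through one flattened tile index by the number of CTAs.

The helper names (pad_problem_sizes, tiles_for_cta, TileWork) are illustrative only. The model shows where a block would cross a group boundary and therefore need its tensormaps updated.

```python
from dataclasses import dataclass
from typing import Iterator, List, Tuple

# Assumed row layout of problem_sizes_lt: (M, N, K, L) per group.
ProblemSize = Tuple[int, int, int, int]


@dataclass(frozen=True)
class TileWork:
    linear_idx: int      # position in the flattened tile space of all groups
    group: int           # which GEMM problem this tile belongs to
    m_tile: int
    n_tile: int
    group_changed: bool  # differs from this CTA's previous group: tensormaps need updating


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def pad_problem_sizes(sizes: List[ProblemSize], max_groups: int) -> List[ProblemSize]:
    """Pad to a fixed max_groups x 4 table (like problem_sizes_lt); unused rows are zero."""
    if len(sizes) > max_groups:
        raise ValueError("num_groups exceeds max_groups")
    return list(sizes) + [(0, 0, 0, 0)] * (max_groups - len(sizes))


def tiles_for_cta(
    problem_sizes: List[ProblemSize],
    num_groups: int,
    bm: int,
    bn: int,
    cta_id: int,
    num_ctas: int,
) -> Iterator[TileWork]:
    """Yield the output tiles a persistent CTA would visit (hypothetical model)."""
    # Output tiles per group; K and L are ignored in this simplified model.
    counts = [ceil_div(m, bm) * ceil_div(n, bn) for m, n, _k, _l in problem_sizes[:num_groups]]
    total = sum(counts)
    prev_group = -1
    for linear_idx in range(cta_id, total, num_ctas):
        # Find the group whose tile range contains linear_idx (skips empty groups).
        group, base = 0, 0
        while linear_idx >= base + counts[group]:
            base += counts[group]
            group += 1
        local = linear_idx - base
        n_tiles = ceil_div(problem_sizes[group][1], bn)
        m_tile, n_tile = divmod(local, n_tiles)  # assumed row-major order within a group
        yield TileWork(linear_idx, group, m_tile, n_tile, group != prev_group)
        prev_group = group
```

Consider two groups of shape 256x128 and 128x256 with 128x128 tiles. A single CTA visits four tiles and sees a group change on the first and third. With two CTAs, each CTA visits one tile per group, so each of its tiles starts a new group. This is why per-block (rather than global) tensormap management is a natural fit.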

load_input_tiles

static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * 
ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_op: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], 
ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), 
_to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_op: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_op: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * 
ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * 
ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], 
ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], tiles: ProducerTiles[tiles_origin, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)

Load the A and B operand tiles and their scale-factor tiles (SFA, SFB) using TMA, through the InputProducerStage.
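As a rough sketch of the byte accounting behind these loads: the struct's `a_expected_bytes` member is `BM * BK * size_of[a_type]`. The plain-Python model below applies that formula to A and, as an assumption, the analogous `BN * BK * size_of[b_type]` to B (that is, it assumes the B tile spans BN rows). The helper name `stage_bytes` and the example tile sizes are hypothetical. Scale-factor bytes are left out because their tile layout depends on `SF_MN_GROUP_SIZE` and `SF_ATOM_M`. In 2-SM mode a single CTA may load only part of a tile, so treat the results as per-tile figures, not per-CTA figures.

```python
def stage_bytes(bm: int, bn: int, bk: int, a_size: int, b_size: int) -> dict:
    """Bytes delivered by the A and B TMA loads for one K-slice of a tile.

    a: BM * BK * size_of[a_type]   (same formula as a_expected_bytes)
    b: BN * BK * size_of[b_type]   (assumed analogous formula for B)
    SFA/SFB scale-factor bytes are not modeled.
    """
    a = bm * bk * a_size
    b = bn * bk * b_size
    return {"a": a, "b": b, "total": a + b}


# Hypothetical example: BM=128, BN=256, BK=128, 1-byte (FP8-style) elements.
print(stage_bytes(128, 256, 128, 1, 1))
# {'a': 16384, 'b': 32768, 'total': 49152}
```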

mma

static mma[tiles_origin: MutOrigin, //](tiles: ConsumerTiles[tiles_origin, BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BN, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM0, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_pipeline_stages], GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, 
max_groups, cluster_shape, elementwise_compute_lambda_fn].SmemType.Core.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, tmem_region: BlockScaledTmem[DType.float32, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_M, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_accum_pipeline_stages, sfa_dtype, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM, GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].num_pipeline_stages, cta_group=GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].cta_group, num_sf_k_tiles=config.num_sf_k_tiles], iter_idx: UInt32, k_start: UInt32)

Execute the block-scaled MMA operations on the tiles held by ConsumerTiles, accumulating into tensor memory (TMEM).
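To make explicit what a block-scaled MMA computes, the arithmetic can be written as a slow host-side reference, which is handy for checking kernel output. This is a sketch under stated assumptions and not the kernel's code. It assumes that every `vec_size` consecutive K elements of a row of A (or of B) share one scale factor from SFA (or SFB), and that B is stored N x K, as with `transpose_b=True`. The real scale-vector length depends on `config.scaling_kind` and is not given here. The function name is hypothetical. Python floats are double precision, while the kernel accumulates in float32, so the last bits of the results may differ.

```python
def block_scaled_matmul_ref(a, sfa, b, sfb, vec_size):
    """Reference D[i][j] = sum_k (a[i][k] * sfa[i][k // vec_size])
                                * (b[j][k] * sfb[j][k // vec_size]).

    a:   M x K operand values;  sfa: M x (K // vec_size) scale factors
    b:   N x K operand values (K-major B, as with transpose_b=True)
    sfb: N x (K // vec_size) scale factors
    """
    m, k, n = len(a), len(a[0]), len(b)
    if k % vec_size != 0:
        raise ValueError("K must be a multiple of vec_size")
    d = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0  # the kernel keeps this accumulator in float32 TMEM
            for kk in range(k):
                blk = kk // vec_size  # index of the scale block containing kk
                acc += (a[i][kk] * sfa[i][blk]) * (b[j][kk] * sfb[j][blk])
            d[i][j] = acc
    return d


# One row of A, one row of B, K=4 split into two scale blocks of size 2.
print(block_scaled_matmul_ref([[1, 2, 3, 4]], [[2.0, 0.5]],
                              [[1, 1, 1, 1]], [[1.0, 4.0]], vec_size=2))
# [[20.0]]
```

With all scale factors equal to 1, the function reduces to an ordinary `A @ B^T`.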

epilogue

static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], 
ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], stage: OutputStage[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].opc], work_tile_coord: Tuple[UInt32, UInt32, UInt32], M: UInt32, N: UInt32, alpha: Float32 = 1)

Execute the epilogue, storing the accumulated results to C through TMA.
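The following host-side model shows what the epilogue is assumed to do for one tile. It scales the accumulator by `alpha`, optionally applies the elementwise lambda, and skips rows or columns that fall beyond the group's `M` x `N` extent. Several details are assumptions made for illustration: the mapping from `work_tile_coord` to row and column offsets (tile index times tile size), the order in which `alpha` and the lambda are applied, and the helper name `epilogue_ref`. The `elementwise` callback mirrors the shape of `elementwise_compute_lambda_fn`, which takes a (row, col) index and a value.

```python
def epilogue_ref(c, acc, tile_m, tile_n, m, n, alpha=1.0, elementwise=None):
    """Store one BM x BN accumulator tile into C (a list of rows).

    Assumes the tile's origin is (tile_m * BM, tile_n * BN) and that
    elements beyond the group's m x n extent are simply not written.
    """
    bm, bn = len(acc), len(acc[0])
    for i in range(bm):
        row = tile_m * bm + i
        if row >= m:
            break  # remaining tile rows lie outside this group's C
        for j in range(bn):
            col = tile_n * bn + j
            if col >= n:
                break  # remaining tile columns lie outside this group's C
            v = alpha * acc[i][j]  # assumed order: scale by alpha first
            if elementwise is not None:
                v = elementwise((row, col), v)  # then apply the optional lambda
            c[row][col] = v
    return c


# 3 x 3 group, 2 x 2 tiles: tile (1, 1) has only one in-bounds element.
c = [[0.0] * 3 for _ in range(3)]
print(epilogue_ref(c, [[1.0, 2.0], [3.0, 4.0]], 1, 1, m=3, n=3, alpha=0.5))
# [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.5]]
```

The bounds check matters in the grouped setting because each group has its own M and N, so edge tiles of one group must not write into a neighboring region of C.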

run_2sm

static run_2sm(a_tma_template: TMATensorTile[a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], b_tma_template: TMATensorTile[b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, 
sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], c_tma_template: TMATensorTile[c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfa_tma_template: TMATensorTile[sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * 
ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem 
SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], sfb_tma_template: TMATensorTile[sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * 
ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], 
ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // 
size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_a: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, a_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_b: 
TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, b_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, 
config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_sfa: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].CLUSTER_SIZE, sfa_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem 
SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)].static_value * 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_sfb: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, sfb_dtype, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem 
SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], 
ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[(GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], device_tma_c: TMATensorTileArray[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].CLUSTER_SIZE, c_type, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * 
ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]](), _to_index_list[Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, 
elementwise_compute_lambda_fn].c_tile_dim1], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]].rank, Layout[ComptimeInt[1], ComptimeInt[GroupedBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, max_groups, cluster_shape, elementwise_compute_lambda_fn].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]]()], group_a_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_b_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_c_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfa_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], group_sfb_ptrs_lt: TileTensor[DType.uint64, Layout[ComptimeInt[max_groups], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]], MutAnyOrigin], problem_sizes_lt: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int)
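The per-group arguments in this signature are flat tables sized by `max_groups`. Each `group_*_ptrs_lt` has layout `(max_groups, 1)` of `uint64` addresses, and `problem_sizes_lt` has layout `(max_groups, 4)` of `int32` with a row stride of 4. The following host-side Python sketch shows one way to fill tables with that shape before copying them to the device. The helper name `pack_group_tables`, the dict keys, and the reading of the four sizes as (M, N, K, L) are all assumptions for illustration. They are not part of this API.

```python
from array import array

def pack_group_tables(groups, max_groups):
    """Pack per-group addresses and problem sizes into flat tables.

    Hypothetical helper. Each entry of `groups` is a dict with integer
    device addresses under "a", "b", "c", "sfa", "sfb" and a 4-tuple under
    "mnkl" (assumed to be M, N, K, L). Unused rows are zero-filled.
    """
    if len(groups) > max_groups:
        raise ValueError("num_groups exceeds max_groups")
    # One uint64 column per operand: layout (max_groups, 1).
    ptr_tables = {
        name: array("Q", [0] * max_groups)
        for name in ("a", "b", "c", "sfa", "sfb")
    }
    # Row-major (max_groups, 4) int32 table: row g starts at index 4 * g.
    problem_sizes = array("i", [0] * (max_groups * 4))
    for g, spec in enumerate(groups):
        for name, table in ptr_tables.items():
            table[g] = spec[name]
        for j, value in enumerate(spec["mnkl"]):
            problem_sizes[g * 4 + j] = value
    # The last value plays the role of the runtime `num_groups` argument.
    return ptr_tables, problem_sizes, len(groups)
```

Zero-filling the unused rows keeps the tables at their fixed compile-time size, while `num_groups` tells the kernel how many rows are live.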

Grouped block-scaled GEMM kernel with 2SM (cta_group=2) support.

This entry point uses CLC (cluster launch control) based work distribution to keep the two CTAs of a cluster synchronized in 2SM mode. Both CTAs cooperate on each tile: both issue TMA loads, while only one CTA issues the MMA instructions.
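Whatever hands out work, each work item ultimately has to be turned into a group and a tile within that group. The Python sketch below decodes a linear work index into `(group, m_tile, n_tile)` using prefix sums of per-group tile counts. It is an illustration under stated assumptions, not the kernel's scheduler. In the real kernel the index arrives dynamically through CLC. Here `tile_m`/`tile_n` are hypothetical parameters, and for a 2SM cluster `tile_m` is assumed to be the M extent covered by the CTA pair. The row-major tile order within a group is also an assumption; the actual scheduler may raster tiles differently.

```python
from bisect import bisect_right

def ceil_div(a, b):
    return -(-a // b)

def group_tile_offsets(problem_sizes, tile_m, tile_n):
    """Prefix sums of tile counts; offsets[g] is the first linear index of group g."""
    offsets = [0]
    for m, n, *_rest in problem_sizes:
        offsets.append(offsets[-1] + ceil_div(m, tile_m) * ceil_div(n, tile_n))
    return offsets

def decode_work_item(linear_idx, problem_sizes, tile_m, tile_n):
    """Map a linear tile index to (group, m_tile, n_tile), or None when out of range."""
    offsets = group_tile_offsets(problem_sizes, tile_m, tile_n)
    if linear_idx < 0 or linear_idx >= offsets[-1]:
        return None
    # bisect_right skips groups with zero tiles (duplicate offsets).
    group = bisect_right(offsets, linear_idx) - 1
    local = linear_idx - offsets[group]
    n_tiles = ceil_div(problem_sizes[group][1], tile_n)
    return group, local // n_tiles, local % n_tiles
```

For example, with `problem_sizes = [(256, 256, 64, 1), (128, 128, 64, 1)]` and 128x128 tiles, indices 0 to 3 cover the 2x2 tiles of group 0 and index 4 is the single tile of group 1.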

The architecture matches that of block_scaled_matmul_kernel:

  • Scheduler warp: Produces work items via CLC barriers
  • TMA warp: Loads tiles with tensormap updates on group change
  • MMA warp: Waits on CLC, executes MMA (elected CTA only)
  • Epilogue warps: Stores results, updating the C tensormap on group change
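The warp roles above can be pictured with a sequential Python simulation of one CTA. The TMA warp and the epilogue warps each remember which group their tensormaps currently describe and refresh them only when a work item from a new group arrives. The MMA step runs only on the elected CTA. This is a hedged sketch: `simulate_cta` and the log tuples are hypothetical, and the real warps run concurrently and synchronize through barriers, which are not modeled here. A "tensormap update" entry stands in for rewriting the TMA descriptors for the new group.

```python
def simulate_cta(work_items, is_elected):
    """Replay scheduler work items through the warp roles of one CTA.

    work_items: sequence of (group, m_tile, n_tile) from the scheduler warp.
    is_elected: whether this CTA issues MMA work in the 2SM pair.
    """
    ab_tensormap_group = None  # group described by the A/B/SFA/SFB tensormaps
    c_tensormap_group = None   # group described by the C tensormap
    log = []
    for group, m, n in work_items:
        # TMA warp: refresh input tensormaps only on a group change.
        if group != ab_tensormap_group:
            log.append(("tma_tensormap_update", group))
            ab_tensormap_group = group
        log.append(("tma_load", group, m, n))
        # MMA warp: only the elected CTA issues MMA work.
        if is_elected:
            log.append(("mma", group, m, n))
        # Epilogue warps: refresh the C tensormap only on a group change.
        if group != c_tensormap_group:
            log.append(("c_tensormap_update", group))
            c_tensormap_group = group
        log.append(("store", group, m, n))
    return log
```

Because consecutive tiles usually belong to the same group, tensormap updates are rare compared with loads and stores.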
