Skip to main content

Mojo struct

BlackwellBlockScaledMatmulKernel

struct BlackwellBlockScaledMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple(SIMD(1)), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0]

Block-scaled matmul kernel (V3), ported from the working legacy kernel.

This struct exposes a structured interface while internally reusing the proven legacy kernel logic.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_expected_bytes

comptime a_expected_bytes = ((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM * BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK) * size_of[a_type]())

a_internal_layout

comptime a_internal_layout = Layout(Coord(Coord(Idx[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // 8)](), Idx[8]()), Coord(Idx[(128 // size_of[a_type]())](), Idx[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK * size_of[a_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[a_type]())](), Idx[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // 8) * (128 // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0]())))

a_smem_layout

comptime a_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // 8)]()), Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // (config.a_swizzle.bytes() // size_of[a_type]()))]())), Coord(Coord(Idx[(config.a_swizzle.bytes() // size_of[a_type]())](), Idx[(8 * (config.a_swizzle.bytes() // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK == (config.a_swizzle.bytes() // size_of[a_type]())) else (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM * (config.a_swizzle.bytes() // size_of[a_type]()))]()))).to_layout()

a_swizzle_elems

comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())

a_tile_dim0

comptime a_tile_dim0 = compute_tma_tile_dims[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, config.AB_swapped]()[0]

a_tma_load_size

comptime a_tma_load_size = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0 * BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_swizzle_elems)

a_tma_rows

comptime a_tma_rows = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0

accum_pipeline_consumer_arv_count

comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[1]

accum_pipeline_producer_arv_count

comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[0]

accum_type

comptime accum_type = DType.float32

ADescLayout

comptime ADescLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0], ComptimeInt[(config.a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

ATileLayout

comptime ATileLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0], ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK], ComptimeInt[(ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_tile_dim0].static_value * ComptimeInt[(ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

ATmaOp

comptime ATmaOp = TMATensorTile[a_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout]()]

b_expected_bytes

comptime b_expected_bytes = ((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN * BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK) * size_of[b_type]())

b_internal_layout

comptime b_internal_layout = Layout(Coord(Coord(Idx[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // 8)](), Idx[8]()), Coord(Idx[(128 // size_of[b_type]())](), Idx[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK * size_of[b_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[b_type]())](), Idx[((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // 8) * (128 // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0]())))

b_smem_layout

comptime b_smem_layout = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK == (config.b_swizzle.bytes() // size_of[b_type]())) else (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).to_layout() if transpose_b else Layout(Coord(Coord(Idx[8](), Idx[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK // 8)]()), Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN // (config.b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(config.b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (config.b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN == (config.b_swizzle.bytes() // size_of[b_type]())) else (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK * (config.b_swizzle.bytes() // size_of[b_type]()))]()))).transpose().to_layout()

b_swizzle_elems

comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())

b_tile_dim0

comptime b_tile_dim0 = compute_tma_tile_dims[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, config.AB_swapped]()[1]

b_tma_load_size

comptime b_tma_load_size = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0 * BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_swizzle_elems)

b_tma_rows

comptime b_tma_rows = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0

BDescLayout

comptime BDescLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0], ComptimeInt[(config.b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

BK

comptime BK = config.block_tile_shape[2]

BM

comptime BM = config.block_tile_shape[0]

BN

comptime BN = config.block_tile_shape[1]

BTileLayout

comptime BTileLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0], ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK], ComptimeInt[(ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_tile_dim0].static_value * ComptimeInt[(ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BK].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

BTmaOp

comptime BTmaOp = TMATensorTile[b_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout]()]

c_smem_layout

comptime c_smem_layout = Layout.row_major(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputN)

c_swizzle_elems

comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())

c_tile_dim0

comptime c_tile_dim0 = compute_tma_tile_dims[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, config.AB_swapped]()[2]

c_tile_dim1

comptime c_tile_dim1 = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].OutputN if not config.AB_swapped.__bool__() else BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_swizzle_elems

CDescLayout

comptime CDescLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim0], ComptimeInt[(config.c_swizzle.bytes() // size_of[c_type]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

clc_consumer_arv_count

comptime clc_consumer_arv_count = compute_clc_barrier_counts[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[1]

clc_producer_arv_count

comptime clc_producer_arv_count = compute_clc_barrier_counts[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[0]

clc_throttle_consumer_arv_count

comptime clc_throttle_consumer_arv_count = compute_clc_barrier_counts[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[3]

clc_throttle_producer_arv_count

comptime clc_throttle_producer_arv_count = compute_clc_barrier_counts[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group]()[2]

CLUSTER_M

comptime CLUSTER_M = config.cluster_shape[0]

CLUSTER_N

comptime CLUSTER_N = config.cluster_shape[1]

CLUSTER_SIZE

comptime CLUSTER_SIZE = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M * BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N)

Context

comptime Context = KernelContext[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N]

cta_group

comptime cta_group = config.cta_group

CTileLayout

comptime CTileLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim0], ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim1], ComptimeInt[(ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim0].static_value * ComptimeInt[(ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim1].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].c_tile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

CTmaOp

comptime CTmaOp = TMATensorTile[c_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CDescLayout]()]

EPILOGUE_THREADS

comptime EPILOGUE_THREADS = (4 * WARP_SIZE)

EpilogueCtx

comptime EpilogueCtx = EpilogueWarpContext[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]

input_expected_bytes

comptime input_expected_bytes = ((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group * (((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].a_expected_bytes + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].b_expected_bytes) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].sfa_expected_bytes) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].sfb_expected_bytes)) * config.k_group_size)

InputTilePipeline

comptime InputTilePipeline = InputTilePipeline[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.num_group_pipeline_stages, config.k_group_size]

MMA_K

comptime MMA_K = config.mma_shape[2]

MMA_M

comptime MMA_M = config.mma_shape[0]

MMA_N

comptime MMA_N = config.mma_shape[1]

MMA_THREADS

comptime MMA_THREADS = WARP_SIZE

MmaCtx

comptime MmaCtx = MmaWarpContext[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]

MmaEpilogueSync

comptime MmaEpilogueSync = WarpGroupBarrier[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS), 1]

MmaOp

comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages

comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages

num_group_pipeline_stages

comptime num_group_pipeline_stages = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages // config)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_output_warps

comptime num_output_warps = 4

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS

comptime NUM_THREADS = (((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)

NUM_TMEM_COLS

comptime NUM_TMEM_COLS = 512

opc

comptime opc = OutputPipelineConfig(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group)

OutputM

comptime OutputM = config.output_tile_shape[0]

OutputN

comptime OutputN = config.output_tile_shape[1]

OutputPipeline

comptime OutputPipeline = OutputTilePipeline[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc]

Scheduler

comptime Scheduler = TileScheduler[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[0], config.cluster_shape[1], config.cluster_shape[2]), config.raster_order, config.block_swizzle_size]

SCHEDULER_THREADS

comptime SCHEDULER_THREADS = WARP_SIZE

SF_K_GROUP_SIZE

comptime SF_K_GROUP_SIZE = (4 * config)

sfa_expected_bytes

comptime sfa_expected_bytes = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].sfa_smem_layout.size() * size_of[sfa_dtype]())

SFA_NUM_COLS

comptime SFA_NUM_COLS = (config * (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // 32))

sfa_smem_layout

comptime sfa_smem_layout = tile_sf_layout_k_major[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM, (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFADescLayout

comptime SFADescLayout = Layout[ComptimeInt[1], ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfa_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

SFATileLayout

comptime SFATileLayout = Layout[ComptimeInt[1], ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BM // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

SFATmaOp

comptime SFATmaOp = TMATensorTile[sfa_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFADescLayout]()]

sfb_expected_bytes

comptime sfb_expected_bytes = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].sfb_smem_layout.size() * size_of[sfb_dtype]())

SFB_NUM_COLS

comptime SFB_NUM_COLS = (config * (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N // 32))

sfb_smem_layout

comptime sfb_smem_layout = tile_sf_layout_k_major[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N, (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SF_K_GROUP_SIZE * config), config.vec_sf_size]()

SFBDescLayout

comptime SFBDescLayout = Layout[ComptimeInt[1], ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[(TensorMapSwizzle.SWIZZLE_NONE.bytes() // size_of[sfb_dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

SFBTileLayout

comptime SFBTileLayout = Layout[ComptimeInt[1], ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N // SF_MN_GROUP_SIZE)], ComptimeInt[config.num_sf_k_tiles], ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())], ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)], ComptimeInt[(ComptimeInt[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N // SF_MN_GROUP_SIZE)].static_value * ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[config.num_sf_k_tiles].static_value * ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)].static_value)], ComptimeInt[(ComptimeInt[(load_from_mem SF_ATOM_M.__getitem_param__[0]())].static_value * ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)].static_value)], ComptimeInt[(ComptimeInt[((load_from_mem SF_ATOM_M.__getitem_param__[1]()) * 4)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]]

SFBTmaOp

comptime SFBTmaOp = TMATensorTile[sfb_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBDescLayout]()]

SmemType

comptime SmemType = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]

stage_stride_cols

comptime stage_stride_cols = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].MMA_N

TilePayload

comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, IndexList(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.BN, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.SFA_DIM0, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.SFA_DIM1, __list_literal__=Tuple()), IndexList(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.SFB_DIM0, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.SFB_DIM1, __list_literal__=Tuple()), BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, 
elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.num_pipeline_stages]

TileWriterType

comptime TileWriterType = TileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc, config.c_swizzle, config.AB_swapped, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.OutputM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.OutputN, config.num_output_stages, 4, batched=True]

TMA_LOAD_THREADS

comptime TMA_LOAD_THREADS = WARP_SIZE

Tmem

comptime Tmem = TmemAllocation[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc.cta_group]

TmemDealloc

comptime TmemDealloc = TmemDeallocBarrier[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc.cta_group]

Methods

load_input_tiles

static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout]()], b_tma_op: TMATensorTile[b_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout]()], sfa_tma_op: TMATensorTile[sfa_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, 
max_profiled_tiles_per_SM].SFATileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFADescLayout]()], sfb_tma_op: TMATensorTile[sfb_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBDescLayout]()], tiles: ProducerTiles[tiles_origin, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, 
max_profiled_tiles_per_SM].SmemType.Core.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int, Int], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)

Load A, B, SFA, SFB tiles using TMA with ProducerTiles.

This method uses the structured ProducerStage pattern from matmul_kernels.mojo, with the tiles and barrier encapsulated in the stage.

Args:

  • a_tma_op (TMATensorTile): TMA descriptor for A matrix.
  • b_tma_op (TMATensorTile): TMA descriptor for B matrix.
  • sfa_tma_op (TMATensorTile): TMA descriptor for A scaling factors.
  • sfb_tma_op (TMATensorTile): TMA descriptor for B scaling factors.
  • tiles (ProducerTiles): ProducerStage context with encapsulated tile access.
  • peer_cta_coord (Tuple): (rank_n, rank_m, peer_m_rank) for peer CTA slicing.
  • work_tile_coord (Tuple): (m, n, k_start) coordinates of the work tile.
  • a_multicast_mask (UInt16): Multicast mask for A tiles.
  • b_multicast_mask (UInt16): Multicast mask for B tiles.
  • iter_idx (UInt32): K iteration index (base index for k_group).
  • elect_one_cta (Bool): True if this CTA should call expect_bytes.

mma

static mma[tiles_origin: MutOrigin, //](tiles: ConsumerTiles[tiles_origin, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SmemType.Core.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, sfa_tmem: UInt32, sfb_tmem: UInt32, iter_idx: UInt32, k_start: UInt32)

Execute MMA operations using ConsumerTiles.

This method uses the structured ConsumerStage pattern from matmul_kernels.mojo, with the tiles and barrier encapsulated in the stage.

Args:

  • tiles (ConsumerTiles): ConsumerStage context with encapsulated tile access.
  • mma_op (MmaOpSM100_BlockScaled_SS): Block-scaled MMA operation instance.
  • tmem_addr (UInt32): TMEM address for accumulators.
  • sfa_tmem (UInt32): TMEM base address for A scaling factors.
  • sfb_tmem (UInt32): TMEM base address for B scaling factors.
  • iter_idx (UInt32): K iteration index.
  • k_start (UInt32): Starting K iteration (used to determine init_c).

epilogue

static epilogue(c_tiles: SMemTileArray2DRowMajor[c_type, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN, BlockScaledTileCore[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages], c_tma_op: TMATensorTile[c_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CDescLayout]()], stage: OutputStage[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].opc], work_tile_coord: Tuple[UInt32, UInt32, UInt32], M: UInt32, N: UInt32, alpha: Float32)

Execute epilogue to store accumulated results to global memory.

Uses TileWriter which encapsulates:

  • TmemArrayType.load_fragments() for TMEM load
  • AccumBarrier.arrive() for barrier signaling
  • TMEMToSMemWriter.write_fragments() for SMEM write
  • 3D TMA store (M, N, Batch coordinates)
  • tma_wait_pipelined() for TMA wait

Barrier synchronization (wait/step) is handled by the caller via the consumer() context.

Args:

  • c_tiles (SMemTileArray2DRowMajor): SMEM tile array for C output.
  • c_tma_op (TMATensorTile): TMA descriptor for C matrix.
  • stage (OutputStage): OutputStage from the consumer() context, carrying the pipeline, index, and TMEM.
  • work_tile_coord (Tuple): (m, n, k_start) coordinates.
  • M (UInt32): Problem M dimension.
  • N (UInt32): Problem N dimension.
  • alpha (Float32): Tensor scale factor (scalar).

validate_config

static validate_config()

Validate configuration constraints at compile time.

init_barriers

static init_barriers(ctx: KernelContext[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N], a_tma_op: TMATensorTile[a_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout]()], b_tma_op: TMATensorTile[b_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, 
c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout]()], c_tma_op: TMATensorTile[c_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CDescLayout]()], sfa_tma_op: TMATensorTile[sfa_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, 
transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFADescLayout]()], sfb_tma_op: TMATensorTile[sfb_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBDescLayout]()], input_barriers: SMemArray[SharedMemBarrier, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].Core.num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (config * 2)], clc_full: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], tmem_dealloc: SMemArray[SharedMemBarrier, 1])

Initialize barriers and prefetch TMA descriptors.

run

static run(a_tma_op: TMATensorTile[a_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ATileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].ADescLayout]()], b_tma_op: TMATensorTile[b_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].BDescLayout]()], c_tma_op: TMATensorTile[c_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, 
_to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].CDescLayout]()], sfa_tma_op: TMATensorTile[sfa_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout](), _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFATileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFADescLayout]()], sfb_tma_op: TMATensorTile[sfb_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout.rank, _to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout](), 
_to_index_list[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBTileLayout.rank, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, pdl_level, max_profiled_tiles_per_SM].SFBDescLayout]()], alpha: Float32, cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])

Kernel entry point — ported from the legacy kernel.

Was this page helpful?