Mojo struct
BlackwellBlockScaledMatmulKernel
struct BlackwellBlockScaledMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0]
Block-scaled matmul kernel (V3), ported from the working legacy kernel.
This struct provides a structured interface while internally reusing the proven legacy kernel logic.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_expected_bytes
comptime a_expected_bytes = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].a_smem_layout.size() * size_of[a_type]())
a_smem_layout
comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.a_swizzle]()
a_tma_load_size
comptime a_tma_load_size = a_desc_layout.size()
a_tma_rows
comptime a_tma_rows = a_desc_layout.shape[1].value()
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group * BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = 1
accum_type
comptime accum_type = DType.float32
b_expected_bytes
comptime b_expected_bytes = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].b_smem_layout.size() * size_of[b_type]())
b_smem_layout
comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.b_swizzle]()
b_tma_load_size
comptime b_tma_load_size = b_desc_layout.size()
b_tma_rows
comptime b_tma_rows = b_desc_layout.shape[1].value()
BK
comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)
BM
comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)
c_smem_layout
comptime c_smem_layout = Layout.row_major(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].OutputM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].OutputN)
clc_consumer_arv_count
comptime clc_consumer_arv_count = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE * ((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)))
clc_producer_arv_count
comptime clc_producer_arv_count = 1
clc_throttle_consumer_arv_count
comptime clc_throttle_consumer_arv_count = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS
clc_throttle_producer_arv_count
comptime clc_throttle_producer_arv_count = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS
CLUSTER_M
comptime CLUSTER_M = Int.__init__[Int](config.cluster_shape.__getitem__[3, DType.int64, Int](0))
CLUSTER_N
comptime CLUSTER_N = Int.__init__[Int](config.cluster_shape.__getitem__[3, DType.int64, Int](1))
CLUSTER_SIZE
comptime CLUSTER_SIZE = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M * BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N)
Context
comptime Context = KernelContext[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N]
cta_group
comptime cta_group = config.cta_group
EPILOGUE_THREADS
comptime EPILOGUE_THREADS = (4 * WARP_SIZE)
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[Int(config), BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]
input_expected_bytes
comptime input_expected_bytes = ((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group * (((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].a_expected_bytes + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].b_expected_bytes) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].sfa_expected_bytes) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].sfb_expected_bytes)) * Int(config))
InputTilePipeline
comptime InputTilePipeline = InputTilePipeline[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, Int(config)]
max_tmem_cols
comptime max_tmem_cols = 512
MMA_K
comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)
MMA_M
comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)
MMA_THREADS
comptime MMA_THREADS = WARP_SIZE
MmaCtx
comptime MmaCtx = MmaWarpContext[Int(config), BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]
MmaEpilogueSync
comptime MmaEpilogueSync = WarpGroupBarrier[(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS), 1]
MmaOp
comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = Int(config)
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = Int(config)
num_group_pipeline_stages
comptime num_group_pipeline_stages = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages // Int(config))
num_output_stages
comptime num_output_stages = Int(config)
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = Int(config)
NUM_THREADS
comptime NUM_THREADS = (((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
OutputM
comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)
OutputN
comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[Int(config), BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]
Scheduler
comptime Scheduler = TileScheduler[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape.__getitem__[3, DType.int64, Int](0), config.cluster_shape.__getitem__[3, DType.int64, Int](1), config.cluster_shape.__getitem__[3, DType.int64, Int](2)), config.raster_order, config.block_swizzle_size]
SCHEDULER_THREADS
comptime SCHEDULER_THREADS = WARP_SIZE
SF_K_GROUP_SIZE
comptime SF_K_GROUP_SIZE = (4 * config)
sfa_expected_bytes
comptime sfa_expected_bytes = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].sfa_smem_layout.size() * size_of[sfa_dtype]())
SFA_NUM_COLS
comptime SFA_NUM_COLS = (config * (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 32))
sfa_smem_layout
comptime sfa_smem_layout = tile_sf_layout_k_major[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SF_K_GROUP_SIZE * config), config.vec_sf_size]()
sfb_expected_bytes
comptime sfb_expected_bytes = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].sfb_smem_layout.size() * size_of[sfb_dtype]())
SFB_NUM_COLS
comptime SFB_NUM_COLS = (config * (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N // 32))
sfb_smem_layout
comptime sfb_smem_layout = tile_sf_layout_k_major[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N, (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SF_K_GROUP_SIZE * config), config.vec_sf_size]()
SmemType
comptime SmemType = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]
stage_stride_cols
comptime stage_stride_cols = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N
TilePayload
comptime TilePayload = BlockScaledTilePayload[a_type, b_type, sfa_dtype, sfb_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.a_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.b_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.sfa_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.sfb_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages]
TileWriterType
comptime TileWriterType = BlockScaledTileWriter[a_type, DType.float32, config.block_tile_shape, config.mma_shape, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, Int(config), config.c_swizzle, config.AB_swapped, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.c_smem_layout, Int(config), BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, 4]
TMA_LOAD_THREADS
comptime TMA_LOAD_THREADS = WARP_SIZE
Tmem
comptime Tmem = TmemAllocation[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]
TmemDealloc
comptime TmemDealloc = TmemDeallocBarrier[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]
Methods
load_input_tiles
static load_input_tiles[tiles_origin: MutOrigin, //](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], tiles: InputProducerStage[tiles_origin, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, Int(config)], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)
Load A, B, SFA, SFB tiles using TMA with InputProducerStage.
This method uses the structured ProducerStage pattern from matmul_kernels.mojo, with tiles and barrier encapsulated in the stage.
Args:
- a_tma_op (TMATensorTile): TMA descriptor for the A matrix.
- b_tma_op (TMATensorTile): TMA descriptor for the B matrix.
- sfa_tma_op (TMATensorTile): TMA descriptor for the A scaling factors.
- sfb_tma_op (TMATensorTile): TMA descriptor for the B scaling factors.
- tiles (InputProducerStage): ProducerStage context with encapsulated tile access.
- peer_cta_coord (Tuple): (rank_n, rank_m, peer_m_rank) for peer CTA slicing.
- work_tile_coord (Tuple): (m, n, k_start) coordinates of the work tile.
- a_multicast_mask (UInt16): Multicast mask for A tiles.
- b_multicast_mask (UInt16): Multicast mask for B tiles.
- iter_idx (UInt32): K iteration index (base index for k_group).
- elect_one_cta (Bool): True if this CTA should call expect_bytes.
mma
static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TilePayload, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, Int(config)], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.scaling_kind, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], tmem_addr: UInt32, sfa_tmem: UInt32, sfb_tmem: UInt32, iter_idx: UInt32, k_start: UInt32)
Execute MMA operations using InputConsumerStage.
This method uses the structured ConsumerStage pattern from matmul_kernels.mojo, with tiles and barrier encapsulated in the stage.
Args:
- tiles (InputConsumerStage): ConsumerStage context with encapsulated tile access.
- mma_op (MmaOpSM100_BlockScaled_SS): Block-scaled MMA operation instance.
- tmem_addr (UInt32): TMEM address for the accumulators.
- sfa_tmem (UInt32): TMEM base address for the A scaling factors.
- sfb_tmem (UInt32): TMEM base address for the B scaling factors.
- iter_idx (UInt32): K iteration index.
- k_start (UInt32): Starting K iteration (used to determine init_c).
epilogue
static epilogue(c_tiles: SMemTileArrayType[c_type, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].c_smem_layout, BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_output_stages, 128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], mma_output_pipeline: ProducerConsumerPipeline[Int(config)], tmem_addr: UInt32, work_tile_coord: Tuple[UInt32, UInt32, UInt32], elect_one_warp: Bool, M: UInt32, N: UInt32)
Execute epilogue to store accumulated results to global memory.
Uses BlockScaledTileWriter which encapsulates:
- TmemArrayType.load_fragments() for TMEM load
- AccumBarrier.arrive() for barrier signaling
- TMEMToSMemWriter.write_fragments() for SMEM write
- 3D TMA store (M, N, Batch coordinates)
- tma_wait_pipelined() for TMA wait
Args:
- c_tiles (SMemTileArrayType): SMEM tile array for the C output.
- c_tma_op (TMATensorTile): TMA descriptor for the C matrix.
- mma_output_pipeline (ProducerConsumerPipeline): Pipeline for MMA→epilogue synchronization.
- tmem_addr (UInt32): Base TMEM address for the accumulators.
- work_tile_coord (Tuple): (m, n, k_start) coordinates.
- elect_one_warp (Bool): Whether this warp should execute (unused).
- M (UInt32): Problem M dimension.
- N (UInt32): Problem N dimension.
validate_config
static validate_config()
Validate configuration constraints at compile time.
run
static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])
Kernel entry point - ported from legacy kernel.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!