Mojo struct
BlackwellBlockScaledMatmulKernel
struct BlackwellBlockScaledMatmulKernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0]
SM100 block-scaled GEMM kernel for MXFP8 (FP8 with microscaling).
Extends the standard matmul with per-block scaling factors (SFA, SFB), which are loaded via TMA, copied to TMEM, and applied during MMA operations.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_smem_layout
comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.a_swizzle]()
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group * BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = 1
accum_type
comptime accum_type = BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b].accum_type
AccumTensor
comptime AccumTensor = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TmemRegion.AccumTile
b_smem_layout
comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.b_swizzle]()
BK
comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)
BM
comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)
c_smem_layout
comptime c_smem_layout = Layout.row_major(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].OutputM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].OutputN)
clc_consumer_arv_count
comptime clc_consumer_arv_count = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE * ((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)))
clc_producer_arv_count
comptime clc_producer_arv_count = 1
clc_throttle_consumer_arv_count
comptime clc_throttle_consumer_arv_count = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS
clc_throttle_producer_arv_count
comptime clc_throttle_producer_arv_count = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS
CLUSTER_M
comptime CLUSTER_M = Int.__init__[Int](config.cluster_shape.__getitem__[3, DType.int64, Int](0))
CLUSTER_N
comptime CLUSTER_N = Int.__init__[Int](config.cluster_shape.__getitem__[3, DType.int64, Int](1))
CLUSTER_SIZE
comptime CLUSTER_SIZE = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M * BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N)
ContextType
comptime ContextType = BlockScaledKernelContext[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages]
cta_group
comptime cta_group = config.cta_group
EPILOGUE_THREADS
comptime EPILOGUE_THREADS = (4 * WARP_SIZE)
EpilogueCtx
comptime EpilogueCtx = EpilogueWarpContext[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]
InputTilePipeline
comptime InputTilePipeline = BlockScaledTilePipeline[a_type, b_type, sfa_dtype, sfb_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].a_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].b_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].sfa_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].sfb_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, 
sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_group_pipeline_stages, Int(config)]
MMA_K
comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)
MMA_M
comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)
MMA_THREADS
comptime MMA_THREADS = WARP_SIZE
MmaCtx
comptime MmaCtx = MmaWarpContext[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS]
MmaOp
comptime MmaOp = MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.block_tile_shape, config.mma_shape, accum_type=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = Int(config)
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = Int(config)
num_group_pipeline_stages
comptime num_group_pipeline_stages = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages // Int(config))
num_output_stages
comptime num_output_stages = Int(config)
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = Int(config)
NUM_THREADS
comptime NUM_THREADS = (((BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
OutputM
comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)
OutputN
comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)
OutputPipeline
comptime OutputPipeline = OutputTilePipeline[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]
Scheduler
comptime Scheduler = TileScheduler[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape.__getitem__[3, DType.int64, Int](0), config.cluster_shape.__getitem__[3, DType.int64, Int](1), config.cluster_shape.__getitem__[3, DType.int64, Int](2)), config.raster_order, config.block_swizzle_size]
SCHEDULER_THREADS
comptime SCHEDULER_THREADS = WARP_SIZE
SFA_NUM_COLS
comptime SFA_NUM_COLS = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // 32)
sfa_smem_layout
comptime sfa_smem_layout = tile_sf_layout_k_major[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, 32]()
SFATensor
comptime SFATensor = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TmemRegion.SFATile
SFB_NUM_COLS
comptime SFB_NUM_COLS = (BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N // 32)
sfb_smem_layout
comptime sfb_smem_layout = tile_sf_layout_k_major[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, 32]()
SFBTensor
comptime SFBTensor = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TmemRegion.SFBTile
SmemType
comptime SmemType = BlockScaledSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config]
stage_stride_cols
comptime stage_stride_cols = BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N
TileWriterType
comptime TileWriterType = BlockScaledTileWriter[a_type, b_type, sfa_dtype, sfb_dtype, transpose_b, config, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].c_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_output_stages, UInt(BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols), 4, elementwise_compute_lambda_fn=elementwise_compute_lambda_fn, register_based_epilogue=register_based_epilogue]
TMA_LOAD_THREADS
comptime TMA_LOAD_THREADS = WARP_SIZE
Tmem
comptime Tmem = TmemAllocation[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]
TmemDeallocBarrier
comptime TmemDeallocBarrier = TmemDeallocBarrier[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]
TmemRegion
comptime TmemRegion = BlockScaledTmem[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, sfa_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]
Methods
validate_config
static validate_config()
Validate configuration constraints at compile time.
expected_bytes_per_k_group
static expected_bytes_per_k_group() -> Int
Calculate expected bytes for TMA barrier per k-group iteration.
Returns:
Int: The expected number of bytes the TMA barrier should receive per k-group iteration.
mma_block_scaled
static mma_block_scaled[tiles_origin: MutOrigin](accum: TmemTensor[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, BlockScaledTmem[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, sfa_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, 
sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group].accum_layout, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group], tmem_region: BlockScaledTmem[BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages, sfa_dtype, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group], tiles: BlockScaledConsumerStage[tiles_origin, a_type, b_type, sfa_dtype, sfb_dtype, BlackwellBlockScaledMatmulKernel[a_type, 
b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].a_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].b_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].sfa_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].sfb_smem_layout, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages, BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, 
cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_group_pipeline_stages, Int(config)], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, config.block_tile_shape, config.mma_shape, accum_type=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, cta_group=BlackwellBlockScaledMatmulKernel[a_type, b_type, c_type, sfa_dtype, sfb_dtype, a_layout, b_layout, c_layout, sfa_layout, sfb_layout, a_desc_layout, b_desc_layout, c_desc_layout, sfa_desc_layout, sfb_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], k_iter: UInt32, k_start: UInt32)
Copy scaling factors to TMEM and execute block-scaled MMA.
Args:
- accum (TmemTensor): Typed TMEM tensor for accumulators.
- tmem_region (BlockScaledTmem): TMEM region with typed accessors for scaling factors.
- tiles (BlockScaledConsumerStage): Consumer stage with A, B, SFA, SFB tiles.
- mma_op (MmaOpSM100_BlockScaled_SS): Block-scaled MMA operation instance.
- k_iter (UInt32): Current K iteration index.
- k_start (UInt32): Starting K iteration (used for init_c).
run
static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])
Kernel entry point. Dispatches to warp-specialized roles.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!