Mojo struct
BlackwellMatmulSM100Kernel
struct BlackwellMatmulSM100Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0]
Blackwell SM100 GEMM kernel with warp specialization.
This struct unifies all parameters and derived types for the SM100 matmul kernel, providing:
- Compile-time parameter validation
- Centralized derived type computation
- Factory methods for kernel components
- Multiple kernel entry points (standard and split-K)
The SM100 kernel uses:
- Tensor Memory (TMEM) for MMA accumulators
- Cluster Launch Control (CLC) for dynamic tile scheduling
- Warp specialization: Scheduler, TMA Load, MMA, Epilogue warps
- Software pipelining for overlapping compute and memory operations
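Software pipelining of this kind usually means producer and consumer warps share a ring of shared-memory stages. Each ring position is tracked as a stage index plus a phase bit that flips every time the index wraps. The following is a minimal Python sketch of that bookkeeping, for illustration only. It assumes the common mbarrier-pipeline convention. The class `PipelineState` and the helper `trace` are hypothetical and are not part of this kernel's API.

```python
class PipelineState:
    """Hypothetical model of a ring-buffer position: (stage index, phase bit)."""

    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.index = 0  # which shared-memory stage is used next
        self.phase = 0  # flips every time the index wraps around

    def step(self) -> None:
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


def trace(num_stages: int, iterations: int) -> list:
    """Return the (stage, phase) pair used by each of `iterations` K steps."""
    state = PipelineState(num_stages)
    out = []
    for _ in range(iterations):
        out.append((state.index, state.phase))
        state.step()
    return out


print(trace(3, 7))  # → [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 0)]
```

The phase bit lets a waiter tell "this stage was released in the current round" apart from "released in the previous round". Without it, a producer could overwrite a stage the consumer has not finished reading.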
Implemented traits
AnyType, UnknownDestructibility
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_smem_layout
comptime a_smem_layout = tile_layout_k_major[a_type, Self.BM, Self.BK, config.a_swizzle]()
accum_pipeline_consumer_arv_count
comptime accum_pipeline_consumer_arv_count = (Self.cta_group * Self.EPILOGUE_THREADS)
accum_pipeline_producer_arv_count
comptime accum_pipeline_producer_arv_count = 1
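The accumulator-pipeline arrival counts follow directly from the two definitions above. The following Python sketch evaluates them. It assumes WARP_SIZE = 32, the warp width on NVIDIA GPUs. The function name is hypothetical.

```python
WARP_SIZE = 32  # warp width on NVIDIA GPUs
EPILOGUE_THREADS = 4 * WARP_SIZE


def accum_arrival_counts(cta_group: int) -> tuple:
    """(producer, consumer) arrival counts for the accumulator barriers."""
    producer = 1                             # accum_pipeline_producer_arv_count
    consumer = cta_group * EPILOGUE_THREADS  # accum_pipeline_consumer_arv_count
    return producer, consumer


for g in (1, 2):
    print(g, accum_arrival_counts(g))
# → 1 (1, 128)
# → 2 (1, 256)
```

The consumer count scales with cta_group. With a 2-CTA MMA group, the barrier therefore expects 256 epilogue arrivals per phase, against a single producer arrival.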
accum_type
comptime accum_type = MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type
b_smem_layout
comptime b_smem_layout = tile_layout_k_major[b_type, Self.BN, Self.BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, Self.BN, Self.BK, config.b_swizzle]()
BK
comptime BK = config.block_tile_shape[2]
BM
comptime BM = config.block_tile_shape[0]
BN
comptime BN = config.block_tile_shape[1]
clc_consumer_arv_count
comptime clc_consumer_arv_count = (Self.SCHEDULER_THREADS + Self.CLUSTER_SIZE * (Self.TMA_LOAD_THREADS + Self.MMA_THREADS + Self.EPILOGUE_THREADS))
clc_producer_arv_count
comptime clc_producer_arv_count = 1
clc_throttle_consumer_arv_count
comptime clc_throttle_consumer_arv_count = Self.SCHEDULER_THREADS
clc_throttle_producer_arv_count
comptime clc_throttle_producer_arv_count = Self.TMA_LOAD_THREADS
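The CLC arrival counts can be evaluated for a concrete cluster shape. The following Python sketch mirrors the formulas above. It assumes WARP_SIZE = 32. CLUSTER_SIZE is computed inline as CLUSTER_M * CLUSTER_N, matching its definition further down. The function name is hypothetical.

```python
WARP_SIZE = 32  # warp width on NVIDIA GPUs
SCHEDULER_THREADS = WARP_SIZE
TMA_LOAD_THREADS = WARP_SIZE
MMA_THREADS = WARP_SIZE
EPILOGUE_THREADS = 4 * WARP_SIZE


def clc_arrival_counts(cluster_m: int, cluster_n: int) -> dict:
    """Expected arrivals on the CLC and CLC-throttle barriers."""
    cluster_size = cluster_m * cluster_n  # CLUSTER_SIZE
    return {
        "clc_producer_arv_count": 1,
        # Scheduler threads once, plus TMA, MMA and epilogue threads
        # multiplied by the number of CTAs in the cluster.
        "clc_consumer_arv_count": SCHEDULER_THREADS
        + cluster_size * (TMA_LOAD_THREADS + MMA_THREADS + EPILOGUE_THREADS),
        "clc_throttle_producer_arv_count": TMA_LOAD_THREADS,
        "clc_throttle_consumer_arv_count": SCHEDULER_THREADS,
    }


print(clc_arrival_counts(2, 1)["clc_consumer_arv_count"])  # → 416
```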
CLUSTER_M
comptime CLUSTER_M = Int(config.cluster_shape[0])
CLUSTER_N
comptime CLUSTER_N = Int(config.cluster_shape[1])
CLUSTER_SIZE
comptime CLUSTER_SIZE = (Self.CLUSTER_M * Self.CLUSTER_N)
Context
comptime Context = KernelContext[Self.num_clc_pipeline_stages, Self.cta_group, Self.CLUSTER_M, Self.CLUSTER_N]
cta_group
comptime cta_group = config.cta_group
EPILOGUE_THREADS
comptime EPILOGUE_THREADS = (4 * WARP_SIZE)
EpilogueConf
comptime EpilogueConf = EpilogueConfig[Self.MMA_M, Self.MMA_N, Self.SmemType.c_smem_layout.shape[1].value(), Self.cta_group, False]
MMA_K
comptime MMA_K = config.mma_shape[2]
MMA_M
comptime MMA_M = config.mma_shape[0]
MMA_N
comptime MMA_N = config.mma_shape[1]
MMA_THREADS
comptime MMA_THREADS = WARP_SIZE
MmaOp
comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, accum_type=Self.accum_type, cta_group=Self.cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]
num_accum_pipeline_stages
comptime num_accum_pipeline_stages = Int(config)
num_clc_pipeline_stages
comptime num_clc_pipeline_stages = Int(config)
num_group_pipeline_stages
comptime num_group_pipeline_stages = (Self.num_pipeline_stages // Int(config))
num_k_mmas
comptime num_k_mmas = (Self.BK // Self.MMA_K)
num_m_mmas
comptime num_m_mmas = (Self.BM // (Self.MMA_M // Self.cta_group))
num_n_mmas
comptime num_n_mmas = (Self.BN // (Self.MMA_N // Self.cta_group))
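The three num_*_mmas members count how many MMA instructions cover one block tile along each axis. The MMA M and N extents are first divided by cta_group to get each CTA's share. The following Python sketch applies the same formulas. The tile and MMA shapes in the example calls are hypothetical values, not defaults taken from this library. The divisibility assertion is an assumption about what a valid configuration requires.

```python
def mma_counts(bm, bn, bk, mma_m, mma_n, mma_k, cta_group):
    """Return (num_m_mmas, num_n_mmas, num_k_mmas) per the formulas above."""
    per_cta_m = mma_m // cta_group
    per_cta_n = mma_n // cta_group
    for tile, step in ((bm, per_cta_m), (bn, per_cta_n), (bk, mma_k)):
        # ASSUMPTION: a valid config tiles evenly into MMA instructions.
        assert tile % step == 0, "block tile must be a multiple of the MMA step"
    return bm // per_cta_m, bn // per_cta_n, bk // mma_k


print(mma_counts(128, 128, 64, 256, 256, 16, cta_group=2))  # → (1, 1, 4)
print(mma_counts(128, 256, 64, 128, 128, 16, cta_group=1))  # → (1, 2, 4)
```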
num_output_stages
comptime num_output_stages = Int(config)
num_output_warps
comptime num_output_warps = 4
num_pipeline_stages
comptime num_pipeline_stages = Int(config)
NUM_THREADS
comptime NUM_THREADS = (Self.SCHEDULER_THREADS + Self.TMA_LOAD_THREADS + Self.MMA_THREADS + Self.EPILOGUE_THREADS)
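With WARP_SIZE = 32, NUM_THREADS comes to 224 threads, or 7 warps, per CTA. The following Python sketch shows that budget and a hypothetical warp-to-role lookup. The lookup assumes roles occupy consecutive warps in the order scheduler, TMA load, MMA, epilogue. The actual kernel's warp assignment is not stated on this page and may differ.

```python
WARP_SIZE = 32  # warp width on NVIDIA GPUs

ROLES = {
    "scheduler": WARP_SIZE,     # SCHEDULER_THREADS
    "tma_load": WARP_SIZE,      # TMA_LOAD_THREADS
    "mma": WARP_SIZE,           # MMA_THREADS
    "epilogue": 4 * WARP_SIZE,  # EPILOGUE_THREADS (num_output_warps = 4)
}

NUM_THREADS = sum(ROLES.values())
NUM_WARPS = NUM_THREADS // WARP_SIZE


def role_of_warp(warp_id: int) -> str:
    """Hypothetical lookup; ASSUMES roles occupy consecutive warps in ROLES order."""
    start = 0
    for name, threads in ROLES.items():
        end = start + threads // WARP_SIZE
        if start <= warp_id < end:
            return name
        start = end
    raise ValueError(f"warp {warp_id} is outside the {NUM_WARPS}-warp CTA")


print(NUM_THREADS, NUM_WARPS)  # → 224 7
```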
NUM_TMEM_COLS
comptime NUM_TMEM_COLS = 512
OutputM
comptime OutputM = config.output_tile_shape[0]
OutputN
comptime OutputN = config.output_tile_shape[1]
OutputRB
comptime OutputRB = OutputRingBuffer[Int(config), Self.stage_stride_cols, Self.cta_group]
RingBuffer
comptime RingBuffer = RingBuffer[a_type, b_type, Self.SmemType.a_smem_layout, Self.SmemType.b_smem_layout, Self.SmemType.num_pipeline_stages, Self.SmemType.num_group_pipeline_stages, Int(config)]
Scheduler
comptime Scheduler = TileScheduler[Self.num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape[0], config.cluster_shape[1], config.cluster_shape[2]), config.raster_order, config.block_swizzle_size]
SCHEDULER_THREADS
comptime SCHEDULER_THREADS = WARP_SIZE
SmemType
comptime SmemType = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config]
stage_stride_cols
comptime stage_stride_cols = (512 // Self.num_accum_pipeline_stages)
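stage_stride_cols splits the 512 tensor-memory columns (NUM_TMEM_COLS) evenly across the accumulator pipeline stages. The following Python sketch computes the stride and the column at which each stage would begin. It assumes stage s starts at s * stage_stride_cols. The function name is hypothetical.

```python
NUM_TMEM_COLS = 512


def tmem_stage_offsets(num_accum_pipeline_stages: int) -> list:
    """Starting TMEM column of each accumulator stage (assumed s * stride)."""
    stride = NUM_TMEM_COLS // num_accum_pipeline_stages  # stage_stride_cols
    return [s * stride for s in range(num_accum_pipeline_stages)]


print(tmem_stage_offsets(2))  # → [0, 256]
print(tmem_stage_offsets(4))  # → [0, 128, 256, 384]
```

More accumulator stages let the MMA warp start the next tile while the epilogue drains the previous one. The trade-off is fewer columns per stage.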
TileLoaderTMA
comptime TileLoaderTMA = TileLoaderTMA[Self.SmemType.a_smem_layout, Self.SmemType.b_smem_layout, Self.BM, Self.BN, Self.BK, Self.MMA_N, Self.cta_group, Int(config), Self.SmemType.num_pipeline_stages, Self.SmemType.num_group_pipeline_stages]
TMA_LOAD_THREADS
comptime TMA_LOAD_THREADS = WARP_SIZE
Methods
validate_constraints
static validate_constraints()
Validate parameter constraints at compile time.
init_barriers
static init_barriers(ctx: KernelContext[Self.num_clc_pipeline_stages, Self.cta_group, Self.CLUSTER_M, Self.CLUSTER_N], a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], tma_mma_mbars_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], accum_mbars_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_throttle_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_full_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_empty_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tmem_dealloc_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])
Initialize barriers and prefetch TMA descriptors. Called only by the single elected thread, the one for which elect_one_warp && elect_one_thread holds.
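Each barrier that init_barriers sets up expects a fixed number of arrivals per phase, such as the *_arv_count members above. The following Python sketch is a rough software model of such a barrier. Once the expected number of arrivals is recorded, the current phase completes, the count resets and the phase bit flips, so the barrier can be reused for the next pipeline round. PhaseBarrier is a hypothetical class and does not model the SharedMemBarrier API.

```python
class PhaseBarrier:
    """Hypothetical software model of an arrival-count, phase-flipping barrier."""

    def __init__(self, expected_arrivals: int):
        self.expected = expected_arrivals
        self.pending = expected_arrivals
        self.phase = 0

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1               # the current phase completes
            self.pending = self.expected  # re-arm for the next phase

    def phase_done(self, phase: int) -> bool:
        # A waiter on `phase` may proceed once the barrier has moved past it.
        return self.phase != phase


bar = PhaseBarrier(expected_arrivals=3)
for _ in range(3):
    bar.arrive()
print(bar.phase_done(0))  # → True
```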
mma
static mma(tmem_addr: UInt32, tiles: ConsumerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)
Execute MMA operations for one pipeline stage.
This is the core MMA function, designed to be called within a consumer tiles context:
    with consumer.get_tiles() as tiles:
        Self.mma(tmem_addr, tiles, mma_op, ...)
Args:
- tmem_addr (UInt32): Tensor memory address for accumulators.
- tiles (ConsumerTiles): ConsumerTiles context with stage, mbar, and tile arrays.
- mma_op (MmaOpSM100_SS): The MMA operation instance.
- elect_one_warp (Bool): Whether this warp should execute.
- iter_idx (UInt32): K iteration index.
- k_start (UInt32): Starting K iteration (used to determine init_c).
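The k_start argument is described as feeding the init_c decision. The usual pattern is that the first K iteration overwrites the accumulator and later iterations add into it, so tensor memory never needs zero-filling. The following Python sketch shows that pattern with plain nested lists. It assumes init_c means "iter_idx == k_start". The helper names are hypothetical.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def mma_step(acc, a_tile, b_tile, init_c):
    """acc = a @ b when init_c is set, otherwise acc += a @ b."""
    prod = matmul(a_tile, b_tile)
    if init_c:
        return prod  # overwrite: previous accumulator contents are ignored
    return [[x + y for x, y in zip(ra, rp)] for ra, rp in zip(acc, prod)]


def k_loop(a_tiles, b_tiles, k_start=0):
    acc = None  # stands in for uninitialized tensor memory
    for iter_idx in range(k_start, len(a_tiles)):
        acc = mma_step(acc, a_tiles[iter_idx], b_tiles[iter_idx],
                       init_c=(iter_idx == k_start))
    return acc


a_tiles = [[[1, 2]], [[3, 4]]]        # two 1x2 K slices of A
b_tiles = [[[1], [1]], [[1], [1]]]    # two 2x1 K slices of B
print(k_loop(a_tiles, b_tiles))       # → [[10]]
```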
run
static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])
Main kernel entry point for SM100 matrix multiplication.
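As a back-of-the-envelope aid, the amount of work that run sees for a problem size mnk = (M, N, K) can be estimated from the block tile shape. The following Python sketch assumes one BM x BN output tile per work item and ceiling division for ragged edges. How this kernel actually maps tiles onto CTAs, clusters and CLC work items is not described on this page. The function names are hypothetical.

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def work_counts(m, n, k, bm, bn, bk):
    """(number of BM x BN output tiles, K iterations per tile)."""
    return ceil_div(m, bm) * ceil_div(n, bn), ceil_div(k, bk)


print(work_counts(4096, 4096, 4096, 128, 128, 64))  # → (1024, 64)
print(work_counts(1000, 1000, 1000, 128, 128, 64))  # → (64, 16)
```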
run_splitk
static run_splitk[reduction_layout: Layout](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], reduction_tensor: LayoutTensor[MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type, reduction_layout, MutAnyOrigin], lock_ptr: LegacyUnsafePointer[UInt8], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])
Split-K kernel entry point for better parallelism on small problems (for example, when M and N yield too few output tiles to occupy every SM).
Split-K divides the K dimension across multiple CTAs: each CTA computes a partial result, and the partial results are then reduced.
Args:
- a_tma_op (TMATensorTile): TMA descriptor for matrix A.
- b_tma_op (TMATensorTile): TMA descriptor for matrix B.
- c_tma_op (TMATensorTile): TMA descriptor for matrix C.
- reduction_tensor (LayoutTensor): Workspace for partial results from each split.
- lock_ptr (LegacyUnsafePointer): Synchronization locks for reduction coordination.
- cluster_dim (StaticTuple): Cluster dimensions.
- mnk (StaticTuple): Problem dimensions (M, N, K).
- workspace (Span): Workspace buffer for profiling/scheduling.
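The split-K idea can be illustrated on a single dot product. The K iterations are divided into contiguous ranges, one per split, and each split computes a partial sum over its range. The partials are then reduced. The following Python sketch assumes a near-equal contiguous partition. The final sum() stands in for the reduction that the kernel coordinates through reduction_tensor and lock_ptr. The helper names are hypothetical, and the kernel's actual split assignment may differ.

```python
def split_k_ranges(num_k_iters: int, num_splits: int) -> list:
    """Divide [0, num_k_iters) into contiguous, nearly equal (begin, end) ranges."""
    base, extra = divmod(num_k_iters, num_splits)
    ranges, start = [], 0
    for s in range(num_splits):
        size = base + (1 if s < extra else 0)
        ranges.append((start, start + size))
        start += size
    return ranges


def split_k_dot(a_row, b_col, bk: int, num_splits: int):
    """Dot product computed as per-split partial sums, then reduced."""
    num_k_iters = -(-len(a_row) // bk)  # ceil(K / BK)
    partials = []
    for k_begin, k_end in split_k_ranges(num_k_iters, num_splits):
        lo, hi = k_begin * bk, min(k_end * bk, len(a_row))
        partials.append(sum(x * y for x, y in zip(a_row[lo:hi], b_col[lo:hi])))
    return sum(partials)  # stand-in for the reduction step


print(split_k_ranges(10, 3))                          # → [(0, 4), (4, 7), (7, 10)]
print(split_k_dot(list(range(10)), [1] * 10, 2, 3))   # → 45
```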