Mojo struct
BlackwellMatmulSM100FallbackKernel
struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, c_layout: TensorLayout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(Int32(1), Int32(1), Int32(1)), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Int = 128, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]
Simple fallback matmul kernel for SM100 (B200).
This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations.
Unlike the main BlackwellMatmulSM100Kernel, this uses:
- Single warp approach (no warp specialization)
- Basic barrier synchronization (no CLC scheduling)
- Direct TileTensor output (no TMA for C)
- Simpler pipeline with single buffer
Implemented traits:
AnyType,
ImplicitlyDestructible
comptime members:
a_size
comptime a_size = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK)
a_smem_layout_typed
comptime a_smem_layout_typed = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // 8)]()), Coord(Idx[(a_swizzle.bytes() // size_of[a_type]())](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (a_swizzle.bytes() // size_of[a_type]()))]())), Coord(Coord(Idx[(a_swizzle.bytes() // size_of[a_type]())](), Idx[(8 * (a_swizzle.bytes() // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (a_swizzle.bytes() // size_of[a_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM * (a_swizzle.bytes() // size_of[a_type]()))]())))
a_swizzle_elems
comptime a_swizzle_elems = (a_swizzle.bytes() // size_of[a_type]())
accum_type
comptime accum_type = get_accum_type[a_type]()
ADescLayout
comptime ADescLayout = Layout[*?, *?]
ATile
comptime ATile = TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
ATileLayout
comptime ATileLayout = Layout[*?, *?]
ATmaOp
comptime ATmaOp = TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
b_size
comptime b_size = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK)
b_smem_layout_typed
comptime b_smem_layout_typed = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // 8)]()), Coord(Idx[(b_swizzle.bytes() // size_of[b_type]())](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (b_swizzle.bytes() // size_of[b_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN * (b_swizzle.bytes() // size_of[b_type]()))]())))
b_swizzle_elems
comptime b_swizzle_elems = (b_swizzle.bytes() // size_of[b_type]())
BDescLayout
comptime BDescLayout = Layout[*?, *?]
BK
comptime BK = block_tile_shape[2]
BM
comptime BM = block_tile_shape[0]
BN
comptime BN = block_tile_shape[1]
BTile
comptime BTile = TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
BTileLayout
comptime BTileLayout = Layout[*?, *?]
BTmaOp
comptime BTmaOp = TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]
c_frag_size
comptime c_frag_size = ((BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N) // num_threads)
CGmemStrideLayout
comptime CGmemStrideLayout = Layout[*?, *?]
max_tmem_cols
comptime max_tmem_cols = 512
MMA_K
comptime MMA_K = mma_shape[2]
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
num_k_mmas
comptime num_k_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_K)
num_m_mmas
comptime num_m_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M)
num_n_mmas
comptime num_n_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N)
static_N
comptime static_N = c_layout.static_stride[0]
Methods:
validate_constraints
static validate_constraints()
Validate compile-time constraints for this kernel configuration.
run
static run(a_tma_op: TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()], c: TileTensor[c_type, c_layout, MutAnyOrigin], num_iters: Int)
Run the fallback matmul kernel.
Args:
- a_tma_op (TMATensorTile[a_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for matrix A.
- b_tma_op (TMATensorTile[b_type, Layout[*?, *?].rank, _to_index_list[Layout[*?, *?]](), _to_index_list[Layout[*?, *?].rank, Layout[*?, *?]]()]): TMA descriptor for matrix B.
- c (TileTensor[c_type, c_layout, MutAnyOrigin]): Output tensor C (TileTensor, direct global memory writes).
- num_iters (Int): Number of K-dimension iterations.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!