Mojo struct
BlackwellMatmulSM100FallbackKernel
struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, c_layout: TensorLayout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(SIMD(1), SIMD(1), SIMD(1)), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Int = 128, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None]
Simple fallback matmul kernel for SM100 (B200).
This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations.
Unlike the main BlackwellMatmulSM100Kernel, this uses:
- Single warp approach (no warp specialization)
- Basic barrier synchronization (no CLC scheduling)
- Direct TileTensor output (no TMA for C)
- Simpler pipeline with single buffer
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
a_size
comptime a_size = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK)
a_smem_layout_typed
comptime a_smem_layout_typed = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // 8)]()), Coord(Idx[(a_swizzle.bytes() // size_of[a_type]())](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (a_swizzle.bytes() // size_of[a_type]()))]())), Coord(Coord(Idx[(a_swizzle.bytes() // size_of[a_type]())](), Idx[(8 * (a_swizzle.bytes() // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (a_swizzle.bytes() // size_of[a_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM * (a_swizzle.bytes() // size_of[a_type]()))]())))
a_swizzle_elems
comptime a_swizzle_elems = (a_swizzle.bytes() // size_of[a_type]())
accum_type
comptime accum_type = get_accum_type[a_type]()
ADescLayout
comptime ADescLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_swizzle_elems], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_swizzle_elems], ComptimeInt[1]]
ATile
comptime ATile = TileTensor[a_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // 8)]], Coord[ComptimeInt[(a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (a_swizzle.bytes() // size_of[a_type]()))]], Coord[ComptimeInt[(a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[(8 * (a_swizzle.bytes() // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (a_swizzle.bytes() // size_of[a_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM * (a_swizzle.bytes() // size_of[a_type]()))]]], MutAnyOrigin, address_space=AddressSpace.SHARED]
ATileLayout
comptime ATileLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[1]]
ATmaOp
comptime ATmaOp = TMATensorTile[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout.rank, _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout](), _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout.rank, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ADescLayout]()]
b_size
comptime b_size = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK)
b_smem_layout_typed
comptime b_smem_layout_typed = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // 8)]()), Coord(Idx[(b_swizzle.bytes() // size_of[b_type]())](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (b_swizzle.bytes() // size_of[b_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN * (b_swizzle.bytes() // size_of[b_type]()))]())))
b_swizzle_elems
comptime b_swizzle_elems = (b_swizzle.bytes() // size_of[b_type]())
BDescLayout
comptime BDescLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_swizzle_elems], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_swizzle_elems], ComptimeInt[1]]
BK
comptime BK = block_tile_shape[2]
BM
comptime BM = block_tile_shape[0]
BN
comptime BN = block_tile_shape[1]
BTile
comptime BTile = TileTensor[b_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // 8)]], Coord[ComptimeInt[(b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (b_swizzle.bytes() // size_of[b_type]()))]], Coord[ComptimeInt[(b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[(8 * (b_swizzle.bytes() // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (b_swizzle.bytes() // size_of[b_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN * (b_swizzle.bytes() // size_of[b_type]()))]]], MutAnyOrigin, address_space=AddressSpace.SHARED]
BTileLayout
comptime BTileLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[1]]
BTmaOp
comptime BTmaOp = TMATensorTile[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout.rank, _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout](), _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout.rank, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BDescLayout]()]
c_frag_size
comptime c_frag_size = ((BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N) // num_threads)
CGmemStrideLayout
comptime CGmemStrideLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].static_N], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
max_tmem_cols
comptime max_tmem_cols = 512
MMA_K
comptime MMA_K = mma_shape[2]
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
num_k_mmas
comptime num_k_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_K)
num_m_mmas
comptime num_m_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M)
num_n_mmas
comptime num_n_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N)
static_N
comptime static_N = c_layout.static_stride[0]
Methods
validate_constraints
static validate_constraints()
Validate compile-time constraints for this kernel configuration.
run
static run(a_tma_op: TMATensorTile[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout.rank, _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout](), _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout.rank, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ADescLayout]()], b_tma_op: TMATensorTile[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout.rank, _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout](), _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout.rank, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BDescLayout]()], c: TileTensor[c_type, c_layout, MutAnyOrigin], num_iters: Int)
Run the fallback matmul kernel.
Args:
- a_tma_op (TMATensorTile): TMA descriptor for matrix A.
- b_tma_op (TMATensorTile): TMA descriptor for matrix B.
- c (TileTensor): Output tensor C (TileTensor, direct global memory writes).
- num_iters (Int): Number of K-dimension iterations.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!