Mojo struct
BlackwellMatmulSM100FallbackKernel
struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: UInt = 128, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]
Simple fallback matmul kernel for SM100 (B200).
This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations.
Unlike the main BlackwellMatmulSM100Kernel, this kernel uses:
- A single-warp approach (no warp specialization)
- Basic barrier synchronization (no CLC scheduling)
- Direct LayoutTensor output (no TMA store for C)
- A simpler, single-buffered pipeline
Implemented traits
AnyType,
UnknownDestructibility
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_size
comptime a_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout.size()
a_smem_layout
comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, a_swizzle]()
accum_type
comptime accum_type = get_accum_type[a_type]()
ATile
comptime ATile = LayoutTensor[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]
b_size
comptime b_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout.size()
b_smem_layout
comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]()
BK
comptime BK = block_tile_shape.__getitem__[3, DType.int64, Int](2)
BM
comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)
BTile
comptime BTile = LayoutTensor[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]
c_frag_size
comptime c_frag_size = ((BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N) // Int(num_threads))
max_tmem_cols
comptime max_tmem_cols = 512
MMA_K
comptime MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)
MMA_M
comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)
num_k_mmas
comptime num_k_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_K)
num_m_mmas
comptime num_m_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M)
num_n_mmas
comptime num_n_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N)
Methods
validate_constraints
static validate_constraints()
Validate compile-time constraints for this kernel configuration.
run
static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], num_iters: UInt)
Run the fallback matmul kernel.
Args:
- a_tma_op (TMATensorTile): TMA descriptor for matrix A.
- b_tma_op (TMATensorTile): TMA descriptor for matrix B.
- c (LayoutTensor): Output tensor C (written as a LayoutTensor, not via TMA).
- num_iters (UInt): Number of K-dimension iterations.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!