Mojo struct
BlackwellMatmulSM100FallbackKernel
struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, c_layout: TensorLayout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Int = 128, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None]
Simple fallback matmul kernel for SM100 (B200).
This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations.
Unlike the main BlackwellMatmulSM100Kernel, this uses:
- A single-warp approach (no warp specialization)
- Basic barrier synchronization (no CLC scheduling)
- Direct TileTensor output (no TMA for C)
- A simpler pipeline with a single buffer
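A minimal parameter-binding sketch follows. The element types and tile shapes are illustrative assumptions, `my_c_layout` is a hypothetical placeholder for the output's TensorLayout, and the import of the kernel struct itself is omitted because its package path is not shown on this page.

```mojo
# Sketch only: binding the kernel's compile-time parameters.
# `my_c_layout` is a hypothetical TensorLayout for the output C; the import
# of BlackwellMatmulSM100FallbackKernel is omitted because its package path
# is not shown on this page.
from utils.index import IndexList

comptime MyKernel = BlackwellMatmulSM100FallbackKernel[
    DType.bfloat16,             # a_type
    DType.bfloat16,             # b_type
    DType.bfloat16,             # c_type
    my_c_layout,                # output layout (placeholder)
    IndexList[3](128, 64, 64),  # block_tile_shape = (BM, BN, BK)
    IndexList[3](128, 64, 16),  # mma_shape = (MMA_M, MMA_N, MMA_K)
]
# The remaining parameters (transpose_b, cluster_shape, a_swizzle, b_swizzle,
# num_threads, elementwise_lambda_fn) keep their documented defaults.
```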
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
a_size
comptime a_size = Self.a_smem_layout.size()
a_smem_layout
comptime a_smem_layout = tile_layout_k_major[a_type, Self.BM, Self.BK, a_swizzle]()
a_swizzle_elems
comptime a_swizzle_elems = (a_swizzle.bytes() // size_of[a_type]())
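For instance, with the default SWIZZLE_128B (assumed to denote a 128-byte swizzle pattern) and a 2-byte a_type such as bfloat16, this evaluates to 128 // 2 = 64 elements.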
accum_type
comptime accum_type = get_accum_type[a_type]()
ADescLayout
comptime ADescLayout = Layout[ComptimeInt[Self.BM], ComptimeInt[Self.a_swizzle_elems], ComptimeInt[Self.a_swizzle_elems], ComptimeInt[1]]
ATile
comptime ATile = LayoutTensor[a_type, Self.a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]
ATileLayout
comptime ATileLayout = Layout[ComptimeInt[Self.BM], ComptimeInt[Self.BK], ComptimeInt[Self.BK], ComptimeInt[1]]
ATmaOp
comptime ATmaOp = Self.ATmaTile.InnerType
ATmaTile
comptime ATmaTile = TMATile[a_type, Self.ATileLayout, Self.ADescLayout]
b_size
comptime b_size = Self.b_smem_layout.size()
b_smem_layout
comptime b_smem_layout = tile_layout_k_major[b_type, Self.BN, Self.BK, b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, Self.BN, Self.BK, b_swizzle]()
b_swizzle_elems
comptime b_swizzle_elems = (b_swizzle.bytes() // size_of[b_type]())
BDescLayout
comptime BDescLayout = Layout[ComptimeInt[Self.BN], ComptimeInt[Self.b_swizzle_elems], ComptimeInt[Self.b_swizzle_elems], ComptimeInt[1]]
BK
comptime BK = block_tile_shape[2]
BM
comptime BM = block_tile_shape[0]
BN
comptime BN = block_tile_shape[1]
BTile
comptime BTile = LayoutTensor[b_type, Self.b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]
BTileLayout
comptime BTileLayout = Layout[ComptimeInt[Self.BN], ComptimeInt[Self.BK], ComptimeInt[Self.BK], ComptimeInt[1]]
BTmaOp
comptime BTmaOp = Self.BTmaTile.InnerType
BTmaTile
comptime BTmaTile = TMATile[b_type, Self.BTileLayout, Self.BDescLayout]
c_frag_size
comptime c_frag_size = (Self.MMA_M * Self.MMA_N) // num_threads
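For example, with an illustrative mma_shape of (128, 64, 16) and the default num_threads = 128, each thread holds (128 * 64) // 128 = 64 accumulator elements.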
CGmemStrideLayout
comptime CGmemStrideLayout = Layout[ComptimeInt[Self.static_N], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
max_tmem_cols
comptime max_tmem_cols = 512
MMA_K
comptime MMA_K = mma_shape[2]
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
num_k_mmas
comptime num_k_mmas = Self.BK // Self.MMA_K
num_m_mmas
comptime num_m_mmas = Self.BM // Self.MMA_M
num_n_mmas
comptime num_n_mmas = Self.BN // Self.MMA_N
static_N
comptime static_N = c_layout.static_stride[0]
Methods
validate_constraints
static validate_constraints()
Validate compile-time constraints for this kernel configuration.
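A host-side launcher would typically invoke this at compile time before enqueuing the kernel, e.g. `MyKernel.validate_constraints()` with a parameterization like the sketch above; this usage is an assumption, since call sites are not shown on this page.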
run
static run(a_tma_op: TMATensorTile[a_type, _to_legacy_layout[Self.ATileLayout](), _to_legacy_layout[Self.ADescLayout]()], b_tma_op: TMATensorTile[b_type, _to_legacy_layout[Self.BTileLayout](), _to_legacy_layout[Self.BDescLayout]()], c: TileTensor[c_type, c_layout, MutAnyOrigin], num_iters: Scalar[DType.uint])
Run the fallback matmul kernel.
Args:
- a_tma_op (TMATensorTile): TMA descriptor for matrix A.
- b_tma_op (TMATensorTile): TMA descriptor for matrix B.
- c (TileTensor): Output tensor C (direct global memory writes).
- num_iters (Scalar): Number of K-dimension iterations.
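In practice, num_iters would typically be the problem's K dimension divided by BK (one iteration per K tile), and a_tma_op / b_tma_op would be TMA descriptors built for the corresponding BM x BK and BN x BK tiles; these usage details are inferred from the signature rather than documented on this page.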