Skip to main content

Mojo struct

BlackwellMatmulSM100FallbackKernel

struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, c_layout: TensorLayout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(SIMD(1), SIMD(1), SIMD(1)), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Int = 128, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None]

Simple fallback matmul kernel for SM100 (B200).

This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations.

Unlike the main BlackwellMatmulSM100Kernel, this uses:

  • Single-warp approach (no warp specialization)
  • Basic barrier synchronization (no CLC scheduling)
  • Direct TileTensor output (no TMA for C)
  • Simpler pipeline with a single buffer

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_size

comptime a_size = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK)

a_smem_layout_typed

comptime a_smem_layout_typed = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // 8)]()), Coord(Idx[(a_swizzle.bytes() // size_of[a_type]())](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (a_swizzle.bytes() // size_of[a_type]()))]())), Coord(Coord(Idx[(a_swizzle.bytes() // size_of[a_type]())](), Idx[(8 * (a_swizzle.bytes() // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (a_swizzle.bytes() // size_of[a_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM * (a_swizzle.bytes() // size_of[a_type]()))]())))

a_swizzle_elems

comptime a_swizzle_elems = (a_swizzle.bytes() // size_of[a_type]())

accum_type

comptime accum_type = get_accum_type[a_type]()

ADescLayout

comptime ADescLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_swizzle_elems], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_swizzle_elems], ComptimeInt[1]]

ATile

comptime ATile = TileTensor[a_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // 8)]], Coord[ComptimeInt[(a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (a_swizzle.bytes() // size_of[a_type]()))]], Coord[ComptimeInt[(a_swizzle.bytes() // size_of[a_type]())], ComptimeInt[(8 * (a_swizzle.bytes() // size_of[a_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (a_swizzle.bytes() // size_of[a_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM * (a_swizzle.bytes() // size_of[a_type]()))]]], MutAnyOrigin, address_space=AddressSpace.SHARED]

ATileLayout

comptime ATileLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[1]]

ATmaOp

comptime ATmaOp = TMATensorTile[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout.rank, _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout](), _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout.rank, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ADescLayout]()]

b_size

comptime b_size = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK)

b_smem_layout_typed

comptime b_smem_layout_typed = Layout(Coord(Coord(Idx[8](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // 8)]()), Coord(Idx[(b_swizzle.bytes() // size_of[b_type]())](), Idx[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (b_swizzle.bytes() // size_of[b_type]()))]())), Coord(Coord(Idx[(b_swizzle.bytes() // size_of[b_type]())](), Idx[(8 * (b_swizzle.bytes() // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (b_swizzle.bytes() // size_of[b_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN * (b_swizzle.bytes() // size_of[b_type]()))]())))

b_swizzle_elems

comptime b_swizzle_elems = (b_swizzle.bytes() // size_of[b_type]())

BDescLayout

comptime BDescLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_swizzle_elems], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_swizzle_elems], ComptimeInt[1]]

BK

comptime BK = block_tile_shape[2]

BM

comptime BM = block_tile_shape[0]

BN

comptime BN = block_tile_shape[1]

BTile

comptime BTile = TileTensor[b_type, Layout[Coord[ComptimeInt[8], ComptimeInt[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // 8)]], Coord[ComptimeInt[(b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // (b_swizzle.bytes() // size_of[b_type]()))]], Coord[ComptimeInt[(b_swizzle.bytes() // size_of[b_type]())], ComptimeInt[(8 * (b_swizzle.bytes() // size_of[b_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK == (b_swizzle.bytes() // size_of[b_type]())) else (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN * (b_swizzle.bytes() // size_of[b_type]()))]]], MutAnyOrigin, address_space=AddressSpace.SHARED]

BTileLayout

comptime BTileLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[1]]

BTmaOp

comptime BTmaOp = TMATensorTile[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout.rank, _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout](), _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout.rank, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BDescLayout]()]

c_frag_size

comptime c_frag_size = ((BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N) // num_threads)

CGmemStrideLayout

comptime CGmemStrideLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].static_N], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

max_tmem_cols

comptime max_tmem_cols = 512

MMA_K

comptime MMA_K = mma_shape[2]

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]

num_k_mmas

comptime num_k_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_K)

num_m_mmas

comptime num_m_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M)

num_n_mmas

comptime num_n_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N)

static_N

comptime static_N = c_layout.static_stride[0]

Methods

validate_constraints

static validate_constraints()

Validate compile-time constraints for this kernel configuration.

run

static run(a_tma_op: TMATensorTile[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout.rank, _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout](), _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout.rank, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ADescLayout]()], b_tma_op: TMATensorTile[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout.rank, _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout](), _to_index_list[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout.rank, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BDescLayout]()], c: TileTensor[c_type, c_layout, MutAnyOrigin], num_iters: Int)

Run the fallback matmul kernel.

Args:

  • a_tma_op (TMATensorTile): TMA descriptor for matrix A.
  • b_tma_op (TMATensorTile): TMA descriptor for matrix B.
  • c (TileTensor): Output tensor C, written directly to global memory (no TMA store).
  • num_iters (Int): Number of K-dimension iterations.

Was this page helpful?