Skip to main content

Mojo struct

BlackwellMatmulSM100FallbackKernel

struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: UInt = 128, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]

Simple fallback matmul kernel for SM100 (B200).

This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations.

Unlike the main BlackwellMatmulSM100Kernel, this uses:

  • Single-warp approach (no warp specialization)
  • Basic barrier synchronization (no CLC scheduling)
  • Direct LayoutTensor output (no TMA for C)
  • Simpler pipeline with a single buffer

Implemented traits

AnyType, UnknownDestructibility

comptime members

__del__is_trivial

comptime __del__is_trivial = True

a_size

comptime a_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout.size()

a_smem_layout

comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, a_swizzle]()

accum_type

comptime accum_type = get_accum_type[a_type]()

ATile

comptime ATile = LayoutTensor[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]

b_size

comptime b_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout.size()

b_smem_layout

comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]()

BK

comptime BK = block_tile_shape.__getitem__[3, DType.int64, Int](2)

BM

comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)

BN

comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)

BTile

comptime BTile = LayoutTensor[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]

c_frag_size

comptime c_frag_size = ((BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N) // Int(num_threads))

max_tmem_cols

comptime max_tmem_cols = 512

MMA_K

comptime MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)

MMA_M

comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)

MMA_N

comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)

num_k_mmas

comptime num_k_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_K)

num_m_mmas

comptime num_m_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M)

num_n_mmas

comptime num_n_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N)

Methods

validate_constraints

static validate_constraints()

Validate compile-time constraints for this kernel configuration.

run

static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], num_iters: UInt)

Run the fallback matmul kernel.

Args:

  • a_tma_op (TMATensorTile): TMA descriptor for matrix A.
  • b_tma_op (TMATensorTile): TMA descriptor for matrix B.
  • c (LayoutTensor): Output tensor C (LayoutTensor, not TMA).
  • num_iters (UInt): Number of K-dimension iterations.

Was this page helpful?