Skip to main content

Mojo struct

BlackwellMatmulSM100FallbackKernel

struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, c_layout: TensorLayout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple(1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Int = 128, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None]

Simple fallback matmul kernel for SM100 (B200).

This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations.

Unlike the main BlackwellMatmulSM100Kernel, this kernel uses:

  • A single-warp approach (no warp specialization)
  • Basic barrier synchronization (no CLC scheduling)
  • Direct TileTensor output (no TMA for C)
  • A simpler pipeline with a single buffer

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

a_size

comptime a_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout.size()

a_smem_layout

comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, a_swizzle]()

a_swizzle_elems

comptime a_swizzle_elems = (a_swizzle.bytes() // size_of[a_type]())

accum_type

comptime accum_type = get_accum_type[a_type]()

ADescLayout

comptime ADescLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_swizzle_elems], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_swizzle_elems], ComptimeInt[1]]

ATile

comptime ATile = LayoutTensor[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]

ATileLayout

comptime ATileLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[1]]

ATmaOp

comptime ATmaOp = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATmaTile.InnerType

ATmaTile

comptime ATmaTile = TMATile[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ADescLayout]

b_size

comptime b_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout.size()

b_smem_layout

comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]()

b_swizzle_elems

comptime b_swizzle_elems = (b_swizzle.bytes() // size_of[b_type]())

BDescLayout

comptime BDescLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_swizzle_elems], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_swizzle_elems], ComptimeInt[1]]

BK

comptime BK = block_tile_shape.__getitem__[Int](2)

BM

comptime BM = block_tile_shape.__getitem__[Int](0)

BN

comptime BN = block_tile_shape.__getitem__[Int](1)

BTile

comptime BTile = LayoutTensor[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]

BTileLayout

comptime BTileLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK], ComptimeInt[1]]

BTmaOp

comptime BTmaOp = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTmaTile.InnerType

BTmaTile

comptime BTmaTile = TMATile[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BDescLayout]

c_frag_size

comptime c_frag_size = ((BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N) // num_threads)

CGmemStrideLayout

comptime CGmemStrideLayout = Layout[ComptimeInt[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].static_N], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]

max_tmem_cols

comptime max_tmem_cols = 512

MMA_K

comptime MMA_K = mma_shape.__getitem__[Int](2)

MMA_M

comptime MMA_M = mma_shape.__getitem__[Int](0)

MMA_N

comptime MMA_N = mma_shape.__getitem__[Int](1)

num_k_mmas

comptime num_k_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_K)

num_m_mmas

comptime num_m_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M)

num_n_mmas

comptime num_n_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N)

static_N

comptime static_N = c_layout.static_stride[0]

Methods

validate_constraints

static validate_constraints()

Validate compile-time constraints for this kernel configuration.

run

static run(a_tma_op: TMATensorTile[a_type, _to_legacy_layout[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ATileLayout](), _to_legacy_layout[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].ADescLayout]()], b_tma_op: TMATensorTile[b_type, _to_legacy_layout[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BTileLayout](), _to_legacy_layout[BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, c_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BDescLayout]()], c: TileTensor[c_type, c_layout, MutAnyOrigin], num_iters: Scalar[DType.uint])

Run the fallback matmul kernel.

Args:

  • a_tma_op (TMATensorTile): TMA descriptor for matrix A.
  • b_tma_op (TMATensorTile): TMA descriptor for matrix B.
  • c (TileTensor): Output tensor C; results are written directly to global memory (no TMA).
  • num_iters (Scalar): Number of K-dimension iterations.

Was this page helpful?