
Mojo struct

BlackwellMatmulSM100FallbackKernel

struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, c_layout: TensorLayout, block_tile_shape: IndexList[Int(3)], mma_shape: IndexList[Int(3)], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, Int(3)] = StaticTuple(Int32(1), Int32(1), Int32(1)), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Int = Int(128), elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None]

Simple fallback matmul kernel for SM100 (B200).

This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations.

Unlike the main BlackwellMatmulSM100Kernel, this kernel uses:

  • A single-warp approach (no warp specialization)
  • Basic barrier synchronization (no CLC scheduling)
  • Direct TileTensor output (no TMA for C)
  • A simpler, single-buffered pipeline

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

a_size

comptime a_size = (block_tile_shape[Int(0)] * block_tile_shape[Int(2)])

a_smem_layout_typed

comptime a_smem_layout_typed = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())))

a_swizzle_elems

comptime a_swizzle_elems = (a_swizzle.bytes() // size_of[a_type]())

accum_type

comptime accum_type = get_accum_type[a_type]()

ADescLayout

comptime ADescLayout = Layout[*?, *?]

ATile

comptime ATile = TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

ATileLayout

comptime ATileLayout = Layout[*?, *?]

ATmaOp

comptime ATmaOp = TMATensorTile[a_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

b_size

comptime b_size = (block_tile_shape[Int(1)] * block_tile_shape[Int(2)])
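The a_size and b_size members are element counts for one BM x BK tile of A and one BN x BK tile of B. As a rough illustration, the Python sketch below reproduces that arithmetic outside Mojo. The block_tile_shape value and the 2-byte element size are assumptions chosen for the example, not values taken from this kernel.

```python
# Hypothetical configuration, chosen only for illustration.
block_tile_shape = (64, 128, 64)  # (BM, BN, BK)
BYTES_PER_ELEM = 2                # assumes a 2-byte element type such as bfloat16

BM, BN, BK = block_tile_shape

a_size = BM * BK  # elements in one A tile, mirrors the a_size member
b_size = BN * BK  # elements in one B tile, mirrors the b_size member

# Byte footprint of one A tile plus one B tile at the assumed element size.
a_bytes = a_size * BYTES_PER_ELEM
b_bytes = b_size * BYTES_PER_ELEM

print(a_size, b_size, a_bytes + b_bytes)
```

With these assumed values the A tile holds 4096 elements and the B tile 8192, for 24576 bytes in total.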

b_smem_layout_typed

comptime b_smem_layout_typed = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())))

b_swizzle_elems

comptime b_swizzle_elems = (b_swizzle.bytes() // size_of[b_type]())
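Both swizzle members divide the swizzle width in bytes by the element size, giving the number of elements that fit in one swizzle span. The Python sketch below shows the same division for the default SWIZZLE_128B mode, assuming it corresponds to 128 bytes as its name suggests. The helper name and the element sizes are illustrative and not part of the Mojo API.

```python
def swizzle_elems(swizzle_bytes: int, elem_bytes: int) -> int:
    """Elements per swizzle span: swizzle width divided by element size."""
    return swizzle_bytes // elem_bytes


# Assumed element sizes in bytes for a few common dtypes.
for name, elem_bytes in [("float32", 4), ("bfloat16", 2), ("float8", 1)]:
    print(name, swizzle_elems(128, elem_bytes))
```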

BDescLayout

comptime BDescLayout = Layout[*?, *?]

BK

comptime BK = block_tile_shape[Int(2)]

BM

comptime BM = block_tile_shape[Int(0)]

BN

comptime BN = block_tile_shape[Int(1)]

BTile

comptime BTile = TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

BTileLayout

comptime BTileLayout = Layout[*?, *?]

BTmaOp

comptime BTmaOp = TMATensorTile[b_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]

c_frag_size

comptime c_frag_size = (Int(mma_shape[Int(0)] * mma_shape[Int(1)]) // num_threads)
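c_frag_size divides the MMA_M x MMA_N accumulator tile evenly across the threads of the block. A plausible reading is that it gives the number of accumulator elements each thread handles. The Python sketch below mirrors the formula. The function name and the example MMA shape are assumptions, while the default of 128 threads matches the num_threads default in the struct signature.

```python
def c_frag_size(mma_m: int, mma_n: int, num_threads: int = 128) -> int:
    """Accumulator elements per thread: (MMA_M * MMA_N) // num_threads."""
    return (mma_m * mma_n) // num_threads


# Hypothetical MMA shape of 64 x 128 with the default 128 threads.
print(c_frag_size(64, 128))
```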

CGmemStrideLayout

comptime CGmemStrideLayout = Layout[*?, *?]

max_tmem_cols

comptime max_tmem_cols = 512

MMA_K

comptime MMA_K = mma_shape[Int(2)]

MMA_M

comptime MMA_M = mma_shape[Int(0)]

MMA_N

comptime MMA_N = mma_shape[Int(1)]

num_k_mmas

comptime num_k_mmas = (block_tile_shape[Int(2)] // mma_shape[Int(2)])

num_m_mmas

comptime num_m_mmas = (block_tile_shape[Int(0)] // mma_shape[Int(0)])

num_n_mmas

comptime num_n_mmas = (block_tile_shape[Int(1)] // mma_shape[Int(1)])
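The three num_*_mmas members count how many MMA instructions tile one block tile along M, N, and K, using floor division of each block dimension by the matching MMA dimension. The Python sketch below computes all three at once. The helper name and both shapes are illustrative assumptions.

```python
def mma_counts(block_tile_shape, mma_shape):
    """Return (num_m_mmas, num_n_mmas, num_k_mmas) by per-dimension floor division."""
    return tuple(tile // mma for tile, mma in zip(block_tile_shape, mma_shape))


# Hypothetical shapes: block tile (BM, BN, BK) and MMA (MMA_M, MMA_N, MMA_K).
num_m_mmas, num_n_mmas, num_k_mmas = mma_counts((64, 128, 64), (64, 128, 16))
print(num_m_mmas, num_n_mmas, num_k_mmas)
```

For these assumed shapes, a single MMA covers the block tile in M and N, and four MMAs step through the K dimension.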

static_N

comptime static_N = c_layout.static_stride[Int(0)]
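static_N is read from the leading stride of the C layout. Assuming C is a row-major M x N matrix, its strides are (N, 1), so the first stride equals N. The Python sketch below derives row-major strides from a shape to show why. The helper and the example shape are illustrative and not part of the TensorLayout API.

```python
def row_major_strides(shape):
    """Strides, in elements, of a densely packed row-major array."""
    strides = [1] * len(shape)
    # Walk from the innermost dimension outward, accumulating extents.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


M, N = 256, 512  # hypothetical output shape
strides = row_major_strides((M, N))
static_N = strides[0]  # leading stride equals N for a row-major M x N matrix
print(strides, static_N)
```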

Methods

validate_constraints

static def validate_constraints()

Validate compile-time constraints for this kernel configuration.
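This page does not list which constraints are checked. The Python sketch below only illustrates the kind of checks the comptime members above suggest: block tiles that are multiples of the MMA shape, an accumulator tile that divides evenly across threads, and an MMA_N that fits within max_tmem_cols. Every check here is a guess made for illustration, not the kernel's actual constraint list, and check_config is a hypothetical helper.

```python
MAX_TMEM_COLS = 512  # same value as max_tmem_cols on this page


def check_config(block_tile_shape, mma_shape, num_threads=128):
    """Return a list of problems; an empty list means every check passed.

    These checks are illustrative guesses, not the kernel's real constraints.
    """
    problems = []
    # Each block-tile dimension should be covered by whole MMA instructions.
    for name, tile, mma in zip("MNK", block_tile_shape, mma_shape):
        if tile % mma != 0:
            problems.append(f"B{name}={tile} is not a multiple of MMA_{name}={mma}")
    # The accumulator tile should split evenly across the threads.
    if (mma_shape[0] * mma_shape[1]) % num_threads != 0:
        problems.append("MMA_M * MMA_N is not a multiple of num_threads")
    # Assumed: the accumulator needs at most MMA_N tensor-memory columns.
    if mma_shape[1] > MAX_TMEM_COLS:
        problems.append("MMA_N exceeds max_tmem_cols")
    return problems


print(check_config((64, 128, 64), (64, 128, 16)))
print(check_config((64, 128, 60), (64, 128, 16)))
```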

run

static def run(a_tma_op: TMATensorTile[a_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], c: TileTensor[c_type, c_layout, MutAnyOrigin], num_iters: Int)

Run the fallback matmul kernel.
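The role of num_iters is not spelled out here. A plausible reading is that it counts the K-dimension tiles, that is ceildiv(K, BK), with one accumulation step per tile. The pure-Python reference below sketches that loop structure for C = A x B^T, matching the default transpose_b = True by storing B as N x K. It is a conceptual model only: the function names are hypothetical, and nothing here reflects how the GPU kernel issues TMA loads or MMA instructions.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(-a // b)


def tiled_matmul(a, b_t, bk):
    """Compute C = A @ B^T, accumulating one K-tile of width bk per iteration.

    a is M x K and b_t is N x K (B stored transposed), both as nested lists.
    Returns the result matrix and the number of K iterations performed.
    """
    m, k, n = len(a), len(a[0]), len(b_t)
    num_iters = ceildiv(k, bk)  # assumed meaning of num_iters
    c = [[0.0] * n for _ in range(m)]
    for it in range(num_iters):
        k0 = it * bk
        k1 = min(k0 + bk, k)  # the last tile may be partial
        for i in range(m):
            for j in range(n):
                c[i][j] += sum(a[i][kk] * b_t[j][kk] for kk in range(k0, k1))
    return c, num_iters


a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b_t = [[1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 1, 1]]
c, num_iters = tiled_matmul(a, b_t, bk=2)
print(c, num_iters)
```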

Args: