Mojo struct
BlackwellMatmulSM100FallbackKernel
struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, c_layout: TensorLayout, block_tile_shape: IndexList[Int(3)], mma_shape: IndexList[Int(3)], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, Int(3)] = StaticTuple(Int32(1), Int32(1), Int32(1)), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: Int = Int(128), elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None]
Simple fallback matmul kernel for SM100 (B200).
This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations.
Unlike the main BlackwellMatmulSM100Kernel, this uses:
- Single warp approach (no warp specialization)
- Basic barrier synchronization (no CLC scheduling)
- Direct TileTensor output (no TMA for C)
- Simpler pipeline with single buffer
Implemented traits
comptime members
a_size
comptime a_size = (block_tile_shape[Int(0)] * block_tile_shape[Int(2)])
a_smem_layout_typed
comptime a_smem_layout_typed = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())))
a_swizzle_elems
comptime a_swizzle_elems = (a_swizzle.bytes() // size_of[a_type]())
accum_type
comptime accum_type = get_accum_type[a_type]()
ADescLayout
comptime ADescLayout = Layout[*?, *?]
ATile
comptime ATile = TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
ATileLayout
comptime ATileLayout = Layout[*?, *?]
ATmaOp
comptime ATmaOp = TMATensorTile[a_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]
b_size
comptime b_size = (block_tile_shape[Int(1)] * block_tile_shape[Int(2)])
b_smem_layout_typed
comptime b_smem_layout_typed = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())))
b_swizzle_elems
comptime b_swizzle_elems = (b_swizzle.bytes() // size_of[b_type]())
BDescLayout
comptime BDescLayout = Layout[*?, *?]
BK
comptime BK = block_tile_shape[Int(2)]
BM
comptime BM = block_tile_shape[Int(0)]
BN
comptime BN = block_tile_shape[Int(1)]
BTile
comptime BTile = TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
BTileLayout
comptime BTileLayout = Layout[*?, *?]
BTmaOp
comptime BTmaOp = TMATensorTile[b_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]
c_frag_size
comptime c_frag_size = (Int(mma_shape[Int(0)] * mma_shape[Int(1)]) // num_threads)
CGmemStrideLayout
comptime CGmemStrideLayout = Layout[*?, *?]
max_tmem_cols
comptime max_tmem_cols = 512
MMA_K
comptime MMA_K = mma_shape[Int(2)]
MMA_M
comptime MMA_M = mma_shape[Int(0)]
MMA_N
comptime MMA_N = mma_shape[Int(1)]
num_k_mmas
comptime num_k_mmas = (block_tile_shape[Int(2)] // mma_shape[Int(2)])
num_m_mmas
comptime num_m_mmas = (block_tile_shape[Int(0)] // mma_shape[Int(0)])
num_n_mmas
comptime num_n_mmas = (block_tile_shape[Int(1)] // mma_shape[Int(1)])
static_N
comptime static_N = c_layout.static_stride[Int(0)]
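The derived members above are plain integer arithmetic on the struct parameters. The following Python sketch evaluates them for one example configuration to show how the shapes relate. The tile shapes (a 128x128x64 block tile with a 128x128x16 MMA shape) and the 2-byte bfloat16 element size are illustrative assumptions; num_threads = 128 and SWIZZLE_128B match the struct's defaults, and the sketch assumes SWIZZLE_128B.bytes() returns 128.

```python
# Sketch: evaluate the derived comptime members for one example
# configuration. Formulas mirror the definitions above; the input values
# (tile shapes, bf16 element size, 128-byte swizzle) are assumptions.

def derived_members(block_tile_shape, mma_shape, num_threads,
                    a_elem_bytes, b_elem_bytes, swizzle_bytes):
    BM, BN, BK = block_tile_shape
    MMA_M, MMA_N, MMA_K = mma_shape
    return {
        "BM": BM, "BN": BN, "BK": BK,
        "a_size": BM * BK,                      # elements in one A tile
        "b_size": BN * BK,                      # elements in one B tile
        "a_swizzle_elems": swizzle_bytes // a_elem_bytes,
        "b_swizzle_elems": swizzle_bytes // b_elem_bytes,
        "num_m_mmas": BM // MMA_M,              # MMAs along M per tile
        "num_n_mmas": BN // MMA_N,              # MMAs along N per tile
        "num_k_mmas": BK // MMA_K,              # MMAs along K per BK slab
        "c_frag_size": (MMA_M * MMA_N) // num_threads,  # accum elems/thread
    }

m = derived_members(block_tile_shape=(128, 128, 64),
                    mma_shape=(128, 128, 16),
                    num_threads=128,
                    a_elem_bytes=2, b_elem_bytes=2,   # assumed bfloat16
                    swizzle_bytes=128)                # assumed SWIZZLE_128B
for name, value in m.items():
    print(f"{name} = {value}")
```

With this configuration each thread owns 128 accumulator elements (c_frag_size), and each BK = 64 slab is consumed by four MMAs along K (num_k_mmas).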
Methods
validate_constraints
static def validate_constraints()
Validate compile-time constraints for this kernel configuration.
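This page does not list the constraints that validate_constraints() enforces. As a hypothetical sketch only, the Python below checks the divisibility conditions implied by the floor divisions in num_m_mmas, num_n_mmas, num_k_mmas, and c_frag_size: if any of them left a remainder, part of a tile or fragment would not be covered. The Mojo method runs at compile time; this sketch instead returns the list of violated conditions so it can be run directly, and its checks and messages are assumptions, not the kernel's actual rules.

```python
# Hypothetical sketch: divisibility checks suggested by the floor
# divisions in the comptime members. The real validate_constraints()
# may check more, fewer, or different conditions.

def validate_constraints(block_tile_shape, mma_shape, num_threads):
    BM, BN, BK = block_tile_shape
    MMA_M, MMA_N, MMA_K = mma_shape
    errors = []
    if BM % MMA_M != 0:                       # num_m_mmas = BM // MMA_M
        errors.append("BM must be a multiple of MMA_M")
    if BN % MMA_N != 0:                       # num_n_mmas = BN // MMA_N
        errors.append("BN must be a multiple of MMA_N")
    if BK % MMA_K != 0:                       # num_k_mmas = BK // MMA_K
        errors.append("BK must be a multiple of MMA_K")
    if (MMA_M * MMA_N) % num_threads != 0:    # c_frag_size
        errors.append("MMA_M * MMA_N must be divisible by num_threads")
    return errors

print(validate_constraints((128, 128, 64), (128, 128, 16), 128))  # []
print(validate_constraints((128, 128, 40), (128, 128, 16), 128))
# ['BK must be a multiple of MMA_K']
```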
run
static def run(a_tma_op: TMATensorTile[a_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], b_tma_op: TMATensorTile[b_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()], c: TileTensor[c_type, c_layout, MutAnyOrigin], num_iters: Int)
Run the fallback matmul kernel.
Args:
- a_tma_op (TMATensorTile[a_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]): TMA descriptor for matrix A.
- b_tma_op (TMATensorTile[b_type, Int(2), _to_index_list[Layout[*?, *?]](), _to_index_list[Int(2), Layout[*?, *?]]()]): TMA descriptor for matrix B.
- c (TileTensor[c_type, c_layout, MutAnyOrigin]): Output tensor C (TileTensor, direct global memory writes).
- num_iters (Int): Number of K-dimension iterations.
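To make the run() arguments concrete, the Python sketch below models on the CPU the arithmetic that one BM x BN output tile receives: num_iters passes over K, each consuming one BK-wide slab of A and B from a single buffer. It assumes row-major A (M x K), B stored as N x K because transpose_b defaults to True, K divisible by BK so that num_iters = K // BK, and one thread block per output tile. The helper reference_tile is hypothetical, and the sketch ignores TMA, swizzling, barriers, and tensor memory.

```python
# Hypothetical CPU reference for one output tile of the fallback kernel.
# Assumptions: A is M x K row-major, B is N x K (transpose_b=True),
# K is a multiple of BK, and num_iters = K // BK.

def reference_tile(A, B, tile_m, tile_n, BM, BN, BK, num_iters):
    # Accumulator for one BM x BN output tile.
    acc = [[0.0] * BN for _ in range(BM)]
    for it in range(num_iters):
        k0 = it * BK                          # start of this K slab
        # Single buffer: load both slabs, then consume them fully
        # before the next iteration reuses the buffer.
        a_slab = [A[tile_m * BM + i][k0:k0 + BK] for i in range(BM)]
        b_slab = [B[tile_n * BN + j][k0:k0 + BK] for j in range(BN)]
        for i in range(BM):
            for j in range(BN):
                acc[i][j] += sum(a * b for a, b in zip(a_slab[i], b_slab[j]))
    return acc

M, N, K = 4, 4, 8
BM, BN, BK = 2, 2, 4
A = [[1.0] * K for _ in range(M)]             # M x K, all ones
B = [[float(j)] * K for j in range(N)]        # N x K, row j filled with j
num_iters = K // BK                           # 2 K iterations
tile = reference_tile(A, B, tile_m=1, tile_n=1,
                      BM=BM, BN=BN, BK=BK, num_iters=num_iters)
print(tile)  # [[16.0, 24.0], [16.0, 24.0]]
```

Each entry equals K * j for the output column j, so tile (1, 1), which covers columns 2 and 3, holds 16 and 24 after both K iterations.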