Mojo struct

HopperMatmulSM90Kernel

struct HopperMatmulSM90Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_smem_layout: Layout, block_tile_shape: IndexList[3], wgmma_shape: IndexList[3], cluster_shape: StaticTuple[Int32, 3], num_pipeline_stages: Int, num_threads: Int = 128, transpose_b: Bool = True, a_swizzle: TensorMapSwizzle = 3, b_swizzle: TensorMapSwizzle = 3, c_swizzle: TensorMapSwizzle = 0, partitioned_multicast: Bool = False, use_tma_store: Bool = False, promotion_frequency: Int = 1, pdl_level: PDLLevel = PDLLevel(), elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, hilbert_swizzle: Bool = False]

Hopper SM90 Matrix Multiplication kernel optimized for NVIDIA H100 GPUs.

This kernel implements a highly optimized matrix multiplication (GEMM) using:

  • Tensor Memory Accelerator (TMA) for efficient global-to-shared memory transfers
  • Warp Group Matrix Multiply Accumulate (WGMMA) instructions for tensor cores
  • Multi-stage software pipelining for overlapping compute and memory operations
  • Producer-consumer model with separate warp groups for loading and computing

Template Parameters:

  • a_type, b_type, c_type: Data types for the input and output matrices.
  • a_layout, b_layout, c_layout: Memory layouts for the matrices.
  • c_smem_layout: Shared memory layout for the output tile.
  • block_tile_shape: Tile dimensions [M, N, K] processed by each thread block.
  • wgmma_shape: Dimensions [M, N, K] of each WGMMA instruction.
  • cluster_shape: Thread block cluster dimensions for distributed shared memory.
  • num_pipeline_stages: Number of stages in the software pipeline (typically 3-7).
  • num_threads: Number of threads per block (must be a multiple of 128).
  • transpose_b: Whether the B matrix is transposed (required to be True).
  • a_swizzle, b_swizzle: Memory swizzling for bank-conflict-free access.
  • c_swizzle: Swizzling for output writes.
  • partitioned_multicast: Enable partitioned multicast for large tiles.
  • use_tma_store: Use TMA for storing the output (vs. regular stores).
  • promotion_frequency: How often to promote FP8 accumulation to higher precision.
  • pdl_level: Programmatic Dependency Launch (PDL) level.
  • elementwise_lambda_fn: Optional epilogue function.
  • elementwise_compute_lambda_fn: Optional compute function.
  • hilbert_swizzle: Use a Hilbert curve for thread block scheduling.

The sketch below shows one way these parameters might be filled in.
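
As a rough illustration of how these parameters compose, the sketch below parameterizes the kernel for a BF16 GEMM. Every concrete value (problem shape, layouts, tile shapes, stage count, thread count) is a placeholder chosen for illustration rather than a tuned or verified configuration, and the import paths are assumptions.

    # Illustrative parameterization only; values and import paths are assumptions.
    # from layout import Layout                  (assumed import)
    # from utils import Index, StaticTuple       (assumed imports)
    # (import path for HopperMatmulSM90Kernel omitted; it depends on the package layout)

    alias M = 8192
    alias N = 8192
    alias K = 8192

    alias MyKernel = HopperMatmulSM90Kernel[
        DType.bfloat16,                     # a_type
        DType.bfloat16,                     # b_type
        DType.bfloat16,                     # c_type
        Layout.row_major(M, K),             # a_layout
        Layout.row_major(N, K),             # b_layout (B is N x K since transpose_b = True)
        Layout.row_major(M, N),             # c_layout
        Layout.row_major(128, 256),         # c_smem_layout (placeholder output tile layout)
        Index(128, 256, 64),                # block_tile_shape [BM, BN, BK]
        Index(64, 256, 16),                 # wgmma_shape [M, N, K] per instruction
        StaticTuple[Int32, 3](1, 1, 1),     # cluster_shape
        4,                                  # num_pipeline_stages
        num_threads=384,                    # 1 producer + 2 consumer warp groups
    ]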

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = True

a_smem_layout

alias a_smem_layout = tile_layout_k_major[a_type, block_tile_shape.__getitem__[3, DType.int64, Int](0), block_tile_shape.__getitem__[3, DType.int64, Int](2), a_swizzle]()

accum_type

alias accum_type = get_accum_type[a_type]()

AccumRegTileType

alias AccumRegTileType = LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1)) * (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))), ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)), MutableAnyOrigin, address_space=AddressSpace(5)]

b_smem_layout

alias b_smem_layout = tile_layout_k_major[b_type, block_tile_shape.__getitem__[3, DType.int64, Int](1), block_tile_shape.__getitem__[3, DType.int64, Int](2), b_swizzle]()

BK

alias BK = block_tile_shape.__getitem__[3, DType.int64, Int](2)

BM

alias BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)

BN

alias BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)

c_frag_size

alias c_frag_size = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)
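
For example, with wgmma_shape = [64, 256, 16], each of the 128 threads in a warp group holds c_frag_size = (64 * 256) / 128 = 128 accumulator elements per WGMMA tile.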

cluster_size

alias cluster_size = Int.__init__[Int32](((cluster_shape.__getitem__[3, Int](0) * cluster_shape.__getitem__[3, Int](1)) * cluster_shape.__getitem__[3, Int](2)))

num_consumer

alias num_consumer = ((num_threads // 128) - 1)

num_consumer_threads

alias num_consumer_threads = (((num_threads // 128) - 1) * 128)

num_m_mmas

alias num_m_mmas = ((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1))

num_n_mmas

alias num_n_mmas = (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))
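
As an illustrative configuration: with num_threads = 384, block_tile_shape = [128, 256, 64], and wgmma_shape = [64, 256, 16], there are num_consumer = 384 / 128 - 1 = 2 consumer warp groups (num_consumer_threads = 256), and each consumer issues num_m_mmas = (128 / 64) / 2 = 1 by num_n_mmas = 256 / 256 = 1 WGMMA tiles per K step, so AccumRegTileType is a 1 x 128 register tile per thread.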

RingBuffer

alias RingBuffer[tma_transfer: Bool = True] = RingBuffer[a_type, b_type, tile_layout_k_major[a_type, block_tile_shape.__getitem__[3, DType.int64, Int](0), block_tile_shape.__getitem__[3, DType.int64, Int](2), a_swizzle](), tile_layout_k_major[b_type, block_tile_shape.__getitem__[3, DType.int64, Int](1), block_tile_shape.__getitem__[3, DType.int64, Int](2), b_swizzle](), num_pipeline_stages, ((num_threads // 128) - 1), Int.__init__[Int32](((cluster_shape.__getitem__[3, Int](0) * cluster_shape.__getitem__[3, Int](1)) * cluster_shape.__getitem__[3, Int](2))), tma_transfer]

Parameters

  • tma_transfer (Bool): Whether tiles are delivered to the ring buffer via TMA transfers (True) or cp.async loads (False).

RingBufferConsumer

alias RingBufferConsumer[origin: MutableOrigin, tma_transfer: Bool] = RingBufferConsumer[origin, RingBuffer[a_type, b_type, tile_layout_k_major[a_type, block_tile_shape.__getitem__[3, DType.int64, Int](0), block_tile_shape.__getitem__[3, DType.int64, Int](2), a_swizzle](), tile_layout_k_major[b_type, block_tile_shape.__getitem__[3, DType.int64, Int](1), block_tile_shape.__getitem__[3, DType.int64, Int](2), b_swizzle](), num_pipeline_stages, ((num_threads // 128) - 1), Int.__init__[Int32](((cluster_shape.__getitem__[3, Int](0) * cluster_shape.__getitem__[3, Int](1)) * cluster_shape.__getitem__[3, Int](2))), tma_transfer]]

Parameters

  • origin (MutableOrigin): Origin of the underlying RingBuffer being consumed.
  • tma_transfer (Bool): Whether tiles are delivered to the ring buffer via TMA transfers (True) or cp.async loads (False).

SMem

alias SMem = HopperMatmulSM90Kernel_SMem[a_type, tile_layout_k_major[a_type, block_tile_shape.__getitem__[3, DType.int64, Int](0), block_tile_shape.__getitem__[3, DType.int64, Int](2), a_swizzle](), b_type, tile_layout_k_major[b_type, block_tile_shape.__getitem__[3, DType.int64, Int](1), block_tile_shape.__getitem__[3, DType.int64, Int](2), b_swizzle](), c_type, c_smem_layout, num_pipeline_stages]

WgmmaOp

alias WgmmaOp = TensorCoreAsync[get_accum_type[a_type](), a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b]

Methods

validate_constraints

static validate_constraints()

Validate common constraints for all kernel variants.

pipeline_init

static pipeline_init()

Initialize pipeline synchronization barriers.

This function ensures that all pipeline initialization (barriers, shared memory) is visible to all thread blocks in the cluster before proceeding. This is critical for correct producer-consumer synchronization.

For multi-cluster configurations, uses fence and cluster sync. For single block, uses a simple barrier.
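
As a rough sketch of that branch (the barrier, fence, and cluster-sync function names below are assumptions about the gpu.sync and gpu.cluster modules, not verified API):

    # Hedged sketch of the synchronization choice; the fence/cluster_sync/barrier
    # names are assumptions, not verified API.
    # from gpu.sync import barrier, fence_mbarrier_init   (assumed imports)
    # from gpu.cluster import cluster_sync                (assumed import)

    fn pipeline_init_sketch[cluster_size: Int]():
        @parameter
        if cluster_size > 1:
            # Make mbarrier initialization visible across the cluster, then wait
            # for every block in the cluster before producers and consumers start.
            fence_mbarrier_init()   # assumed name
            cluster_sync()          # assumed name
        else:
            # Single thread block: a plain block-wide barrier is sufficient.
            barrier()               # assumed name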

finalize_kernel

static finalize_kernel()

Common finalization for all kernel variants.

multicast_mask

static multicast_mask(rank_m: UInt, rank_n: UInt) -> Tuple[Int32, Int32]

Compute the multicast masks for this block from its (rank_m, rank_n) position in the cluster.

Returns:

Tuple: A pair of Int32 multicast masks.

common_kernel_init

static common_kernel_init() -> Tuple[UInt, UInt, UInt, UInt, UInt, Bool]

Common initialization for all kernel variants.

Returns:

Tuple: Tuple of (warp_group_idx, warp_group_thread_idx, rank_m, rank_n, warp_id, lane_predicate).

build_ring_buffer

static build_ring_buffer[tma_transfer: Bool = True](smem: HopperMatmulSM90Kernel_SMem[a_type, tile_layout_k_major[a_type, block_tile_shape.__getitem__[3, DType.int64, Int](0), block_tile_shape.__getitem__[3, DType.int64, Int](2), a_swizzle](), b_type, tile_layout_k_major[b_type, block_tile_shape.__getitem__[3, DType.int64, Int](1), block_tile_shape.__getitem__[3, DType.int64, Int](2), b_swizzle](), c_type, c_smem_layout, num_pipeline_stages], warp_group_thread_idx: UInt) -> RingBuffer[a_type, b_type, tile_layout_k_major[a_type, block_tile_shape.__getitem__[3, DType.int64, Int](0), block_tile_shape.__getitem__[3, DType.int64, Int](2), a_swizzle](), tile_layout_k_major[b_type, block_tile_shape.__getitem__[3, DType.int64, Int](1), block_tile_shape.__getitem__[3, DType.int64, Int](2), b_swizzle](), num_pipeline_stages, ((num_threads // 128) - 1), Int.__init__[Int32](((cluster_shape.__getitem__[3, Int](0) * cluster_shape.__getitem__[3, Int](1)) * cluster_shape.__getitem__[3, Int](2))), tma_transfer]

Create ring buffer for producer-consumer synchronization.

Returns:

RingBuffer: Ring buffer bound to the shared-memory tiles in smem.

setup_producer

static setup_producer() -> Int

Set up the producer warp group by deallocating registers.

Returns:

Int: Number of registers deallocated.

setup_consumer

static setup_consumer(warp_group_idx: UInt) -> Tuple[UInt, LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1)) * (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))), ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)), MutableAnyOrigin, address_space=AddressSpace(5)], LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1)) * (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))), ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)), MutableAnyOrigin, address_space=AddressSpace(5)]]

Set up the consumer warp group.

Returns:

Tuple: Tuple of (local_warp_group_idx, c_reg_tile, final_c_reg_tile).

get_block_swizzle

static get_block_swizzle(lut_ptr: UnsafePointer[UInt32] = UnsafePointer[UInt32, AddressSpace(0), True, MutableAnyOrigin]()) -> IndexList[2, element_type=DType.uint32]

Calculate block swizzle for better L2 cache locality.

Args:

  • lut_ptr (UnsafePointer): Lookup table for Hilbert curve block scheduling (optional).

Returns:

IndexList: Swizzled block indices.
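
For intuition only, here is a generic grouped block swizzle of the kind commonly used for L2 locality. It is not the mapping this kernel actually uses (which, per the parameter list above, may instead follow a Hilbert curve lookup table when hilbert_swizzle is enabled), and boundary handling for ragged grids is omitted.

    # Generic illustration of L2-friendly block swizzling; not this kernel's mapping.
    # from utils import IndexList      (assumed import)

    fn grouped_block_swizzle(linear_block_id: Int, grid_x: Int, group_size: Int) -> IndexList[2]:
        # Walk the grid in narrow bands of `group_size` rows, column by column,
        # so consecutive thread blocks share the same block_x (hence the same B
        # tile) and touch only a small set of A tiles while they are hot in L2.
        var blocks_per_band = grid_x * group_size
        var band = linear_block_id // blocks_per_band
        var within = linear_block_id % blocks_per_band
        var block_y = band * group_size + within % group_size
        var block_x = within // group_size
        return IndexList[2](block_x, block_y)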

consumer_output

static consumer_output[custom_elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = elementwise_lambda_fn](c_tma_op: TMATensorTile[c_type, layout, desc_layout], c: LayoutTensor[c_type, layout, MutableAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_tile: LayoutTensor[c_type, c_smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=128], output_reg_tile: LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1)) * (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))), ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)), MutableAnyOrigin, address_space=AddressSpace(5)], warp_group_thread_idx: UInt, local_warp_group_idx: UInt, local_thread_idx: UInt, block_y: Int, block_x: Int)

Handle consumer output by writing GEMM results to global memory.

build_tma_loaders

static build_tma_loaders[a_tile_layout: Layout, b_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, //](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], rank_m: UInt, rank_n: UInt) -> Tuple[TileLoaderTMA[a_tma_op, a_type, a_tile_layout, a_desc_layout, BK=block_tile_shape.__getitem__[3, DType.int64, Int](2), cluster_size=cluster_shape.__getitem__[3, Int](0), use_partitioned_multicast=partitioned_multicast], TileLoaderTMA[b_tma_op, b_type, b_tile_layout, b_desc_layout, BK=block_tile_shape.__getitem__[3, DType.int64, Int](2), cluster_size=cluster_shape.__getitem__[3, Int](1), use_partitioned_multicast=partitioned_multicast]]

Create TMA tile loaders for the A and B matrices.

Returns:

Tuple: TMA-based tile loaders for A and B.

build_cpasync_loaders

static build_cpasync_loaders[k_align: Int, vector_size: Int = (k_align // size_of[a_type]()), num_threads_per_row: Int = (block_tile_shape.__getitem__[3, DType.int64, Int](2) // vector_size), thread_layout: Layout = Layout.row_major((_resolve_warpgroup_size() // num_threads_per_row), num_threads_per_row)](a: LayoutTensor[a_type, a_layout, MutableAnyOrigin], b: LayoutTensor[b_type, b_layout, MutableAnyOrigin]) -> Tuple[TileLoaderCPAsync[a_type, a_layout, thread_layout, a_swizzle, vector_size], TileLoaderCPAsync[b_type, b_layout, thread_layout, b_swizzle, vector_size]]

Create cp.async tile loaders for the A and B matrices.

Returns:

Tuple: cp.async-based tile loaders for A and B.

producer_main_loop

static producer_main_loop[a_loader_type: TileLoader, b_loader_type: TileLoader, //, num_k_iters: Int](m_coord: UInt, n_coord: UInt, k_coord: UInt, a_loader: a_loader_type, b_loader: b_loader_type, mut ring_buffer: RingBuffer[a_loader_type._dtype, b_loader_type._dtype, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer])

Producer main loop that streams A and B tiles into the ring buffer. Polymorphic over the tile loaders, so it works with both TMA and cp.async loading.

run

static run[a_tile_layout: Layout, b_tile_layout: Layout, c_tma_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], a: LayoutTensor[a_type, a_layout, MutableAnyOrigin], b: LayoutTensor[b_type, b_layout, MutableAnyOrigin], c: LayoutTensor[c_type, c_layout, MutableAnyOrigin], lut_ptr: UnsafePointer[UInt32])

Main kernel entry point for matrix multiplication.

This kernel implements a producer-consumer pattern where:

  • One warp group (producer) loads tiles from global memory using TMA
  • Multiple warp groups (consumers) perform matrix multiplication using tensor cores

The kernel uses software pipelining to overlap memory transfers with computation, achieving high throughput on Hopper GPUs.
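
As a reading aid, here is a heavily simplified sketch of how the static methods documented on this page compose inside the kernel body. It omits shared-memory setup, TMA coordinates, the epilogue, and cluster handling, and the commented-out calls elide their arguments, so treat it as a sketch of the control flow rather than the actual implementation.

    # Reading-aid sketch of the producer/consumer split; not the real run() body.
    Self.validate_constraints()

    var init = Self.common_kernel_init()
    var warp_group_idx = init[0]            # one producer warp group, the rest are consumers
    var warp_group_thread_idx = init[1]

    Self.pipeline_init()
    # var ring_buffer = Self.build_ring_buffer(smem, warp_group_thread_idx)

    if warp_group_idx == 0:
        # Producer: hand back registers, then stream A/B tiles into the ring buffer.
        _ = Self.setup_producer()
        # Self.producer_main_loop[num_k_iters](m, n, k, a_loader, b_loader, ring_buffer)
    else:
        # Consumers: accumulate with WGMMA, then write the output tile.
        # var consumer_state = Self.setup_consumer(warp_group_idx)
        # Self.consumer_main_loop[num_k_iters](wgmma_op, local_idx, final_c_reg_tile, c_reg_tile, rb_consumer)
        # Self.consumer_output(c_tma_op, c, c_tile, c_reg_tile, ...)
        pass

    Self.finalize_kernel()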

Args:

  • a_tma_op (TMATensorTile): TMA descriptor for loading tiles of A.
  • b_tma_op (TMATensorTile): TMA descriptor for loading tiles of B.
  • c_tma_op (TMATensorTile): TMA descriptor for the output matrix C.
  • a (LayoutTensor): Input matrix A.
  • b (LayoutTensor): Input matrix B.
  • c (LayoutTensor): Output matrix C.
  • lut_ptr (UnsafePointer): Lookup table for Hilbert curve block scheduling (used when hilbert_swizzle is enabled).

run_splitk

static run_splitk[a_tile_layout: Layout, b_tile_layout: Layout, c_tma_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, splits: Int, raster_order: RasterOrder](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], c: LayoutTensor[c_type, c_layout, MutableAnyOrigin], workspace_buffer: NDBuffer[get_accum_type[a_type](), 3, MutableAnyOrigin], locks_ptr: UnsafePointer[UInt8], problem_shape: IndexList[3])

Split-K variant of the kernel for better load balancing on small problems.
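
As an illustrative case, a 256 x 256 x 16384 problem yields only a few output tiles; with splits = 4, each split covers 16384 / 4 = 4096 of the K dimension, the per-split partial results are staged in workspace_buffer, and locks_ptr coordinates the reduction into c, so about four times as many thread blocks can run concurrently.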

run_grouped

static run_grouped[a_tile_layout: Layout, b_tile_layout: Layout, c_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tile_layout, c_desc_layout], a_offsets: NDBuffer[DType.uint32, 1, MutableAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutableAnyOrigin], c: LayoutTensor[c_type, c_layout, MutableAnyOrigin])

Grouped matmul variant for MoE (Mixture of Experts) models.

This variant handles multiple experts where each expert processes a subset of tokens. The a_offsets array indicates token boundaries for each expert.
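
As an illustrative example, a_offsets = [0, 128, 352, 512] with expert_ids = [0, 2, 1] would mean rows 0-127 of A are processed by expert 0, rows 128-351 by expert 2, and rows 352-511 by expert 1, each against the corresponding expert's weights in B.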

consumer_main_loop

static consumer_main_loop[ring_buffer_origin: MutableOrigin, //, num_k_iters: Int](wgmma_op: TensorCoreAsync[get_accum_type[a_type](), a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: UInt, final_c_reg_tile: LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1)) * (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))), ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)), MutableAnyOrigin, address_space=AddressSpace(5)], c_reg_tile: LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1)) * (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))), ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)), MutableAnyOrigin, address_space=AddressSpace(5)], mut ring_buffer: RingBufferConsumer[ring_buffer_origin, RingBuffer[a_type, b_type, tile_layout_k_major[a_type, block_tile_shape.__getitem__[3, DType.int64, Int](0), block_tile_shape.__getitem__[3, DType.int64, Int](2), a_swizzle](), tile_layout_k_major[b_type, block_tile_shape.__getitem__[3, DType.int64, Int](1), block_tile_shape.__getitem__[3, DType.int64, Int](2), b_swizzle](), num_pipeline_stages, ((num_threads // 128) - 1), Int.__init__[Int32](((cluster_shape.__getitem__[3, Int](0) * cluster_shape.__getitem__[3, Int](1)) * cluster_shape.__getitem__[3, Int](2))), tma_transfer]])

Main computation loop for consumer warp groups.

This function implements the core matrix multiplication using tensor cores. It consumes tiles from the ring buffer and accumulates results using WGMMA (Warp Group Matrix Multiply Accumulate) instructions.

For FP8 data types, it periodically promotes intermediate results to higher precision to maintain accuracy.

Args:

  • wgmma_op (TensorCoreAsync): Tensor core operator for matrix multiplication.
  • local_warp_group_idx (UInt): Index of this consumer warp group (0-based).
  • final_c_reg_tile (LayoutTensor): Final accumulation register tile (for FP8 promotion).
  • c_reg_tile (LayoutTensor): Working accumulation register tile.
  • ring_buffer (RingBufferConsumer): Consumer handle for synchronized tile access.

promote_to_cuda_cores

static promote_to_cuda_cores(c_reg_tile: LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1)) * (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))), ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)), MutableAnyOrigin, address_space=AddressSpace(5)], final_c_reg_tile: LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1)) * (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))), ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)), MutableAnyOrigin, address_space=AddressSpace(5)])

Promote FP8 accumulation to higher precision using CUDA cores.

When using FP8 data types, tensor cores accumulate in limited precision. To maintain accuracy over many accumulations, we periodically add the intermediate results to a higher-precision accumulator using CUDA cores.

This technique is commonly used in production libraries like cuBLAS to achieve both high performance (from FP8 tensor cores) and good accuracy.

Args:

  • c_reg_tile (LayoutTensor): Current accumulation from tensor cores.
  • final_c_reg_tile (LayoutTensor): Higher-precision accumulator (updated in place).
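
A minimal sketch of the promotion step, assuming plain element-wise (row, column) indexing over the AccumRegTileType layout described above; the real implementation may operate on vectorized fragments instead.

    # Minimal sketch: fold the tensor-core partial sums into the higher-precision
    # running accumulator, one register-tile element at a time.
    for mma_idx in range(Self.num_m_mmas * Self.num_n_mmas):
        for frag_idx in range(Self.c_frag_size):
            final_c_reg_tile[mma_idx, frag_idx] = final_c_reg_tile[mma_idx, frag_idx] + c_reg_tile[mma_idx, frag_idx]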

wgmma

static wgmma(wgmma_op: TensorCoreAsync[get_accum_type[a_type](), a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: UInt, a_tile: LayoutTensor[a_type, layout, MutableAnyOrigin, address_space=AddressSpace(3), element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_tile: LayoutTensor[b_type, layout, MutableAnyOrigin, address_space=AddressSpace(3), element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape.__getitem__[3, DType.int64, Int](0) // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // ((num_threads // 128) - 1)) * (block_tile_shape.__getitem__[3, DType.int64, Int](1) // wgmma_shape.__getitem__[3, DType.int64, Int](1))), ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)), MutableAnyOrigin, address_space=AddressSpace(5)])
