Skip to main content

Mojo struct

HopperMatmulSM90Kernel

struct HopperMatmulSM90Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_smem_layout: Layout, block_tile_shape: IndexList[3], wgmma_shape: IndexList[3], cluster_shape: StaticTuple[Int32, 3], num_pipeline_stages: Int, num_threads: Int = 128, transpose_b: Bool = True, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, partitioned_multicast: Bool = False, use_tma_store: Bool = False, promotion_frequency: Int = 1, pdl_level: PDLLevel = PDLLevel(), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, hilbert_swizzle: Bool = False, k_group_size: Int = 1, swapAB: Bool = False]

Hopper SM90 Matrix Multiplication kernel optimized for NVIDIA H100 GPUs.

This kernel implements a highly optimized matrix multiplication (GEMM) using:

  • Tensor Memory Accelerator (TMA) for efficient global-to-shared memory transfers
  • Warp Group Matrix Multiply Accumulate (WGMMA) instructions for tensor cores
  • Multi-stage software pipelining for overlapping compute and memory operations
  • Producer-consumer model with separate warp groups for loading and computing

Template Parameters:

  • a_type, b_type, c_type: Data types for the input and output matrices.
  • a_layout, b_layout, c_layout: Memory layouts for the matrices.
  • c_smem_layout: Shared memory layout for the output tile.
  • block_tile_shape: Tile dimensions [M, N, K] processed by each thread block.
  • wgmma_shape: Dimensions [M, N, K] of each WGMMA instruction.
  • cluster_shape: Thread block cluster dimensions for distributed shared memory.
  • num_pipeline_stages: Number of stages in the software pipeline (typically 3-7).
  • num_threads: Number of threads per block (must be a multiple of 128; one 128-thread warp group acts as the producer and the remaining warp groups are consumers).
  • transpose_b: Whether the B matrix is transposed (required to be True).
  • a_swizzle, b_swizzle: Memory swizzling for bank-conflict-free access.
  • c_swizzle: Swizzling for output writes.
  • partitioned_multicast: Enable partitioned multicast for large tiles.
  • use_tma_store: Use TMA for storing output (instead of regular stores).
  • promotion_frequency: How often to promote FP8 accumulation to higher precision.
  • pdl_level: Programmatic Dependent Launch (PDL) level.
  • elementwise_lambda_fn: Optional epilogue function.
  • elementwise_compute_lambda_fn: Optional compute function.
  • hilbert_swizzle: Use a Hilbert curve for thread block scheduling.
  • k_group_size: Number of pipeline stages grouped together; the producer-consumer pipeline uses num_pipeline_stages // k_group_size barrier stages.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_smem_layout

comptime a_smem_layout = tile_layout_k_major[a_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, a_swizzle]()

accum_type

comptime accum_type = get_accum_type[a_type]()

AccumRegTile

comptime AccumRegTile = LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]

adjusted_num_pipeline_stages

comptime adjusted_num_pipeline_stages = (num_pipeline_stages // k_group_size)

b_smem_layout

comptime b_smem_layout = tile_layout_k_major[b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, b_swizzle]()

BK

comptime BK = block_tile_shape[2]

BM

comptime BM = block_tile_shape[0]

BN

comptime BN = block_tile_shape[1]

c_frag_size

comptime c_frag_size = ((wgmma_shape[0] * wgmma_shape[1]) // 128)

cluster_size

comptime cluster_size = Int[Int32](((cluster_shape[0] * cluster_shape[1]) * cluster_shape[2]))

num_consumer

comptime num_consumer = ((num_threads // 128) - 1)

num_consumer_threads

comptime num_consumer_threads = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_consumer * 128)

num_m_mmas

comptime num_m_mmas = ((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // wgmma_shape[0]) // HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_consumer)

num_n_mmas

comptime num_n_mmas = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // wgmma_shape[1])

SMem

comptime SMem = HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, c_smem_layout.shape[0].value(), c_smem_layout.shape[1].value(), num_pipeline_stages, k_group_size]

TMABarrier

comptime TMABarrier = TMABarrierHandler[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].SMem.ATileArray.storage_size + HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].SMem.BTileArray.storage_size) // HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages)]

WgmmaOp

comptime WgmmaOp = TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b]

Methods

validate_constraints

static validate_constraints()

Validate common constraints for all kernel variants.

pipeline_init

static pipeline_init()

Initialize pipeline synchronization barriers.

This function ensures that all pipeline initialization (barriers, shared memory) is visible to all thread blocks in the cluster before proceeding. This is critical for correct producer-consumer synchronization.

For clusters containing more than one thread block, it uses a fence followed by a cluster-wide sync. For a single block, it uses a simple barrier.

finalize_kernel

static finalize_kernel()

Common finalization for all kernel variants.

multicast_mask

static multicast_mask(rank_m: Int, rank_n: Int) -> Tuple[Int32, Int32]

Returns:

Tuple[Int32, Int32]

common_kernel_init

static common_kernel_init() -> Tuple[Int, Int, Int, Int, Int, Bool]

Common initialization for all kernel variants.

Returns:

Tuple[Int, Int, Int, Int, Int, Bool]: Tuple of (warp_group_idx, warp_group_thread_idx, rank_m, rank_n, warp_id, lane_predicate).

setup_producer

static setup_producer() -> Int

Set up the producer warp group by deallocating registers.

Returns:

Int: Number of registers deallocated.

setup_consumer

static setup_consumer(warp_group_idx: Int) -> Tuple[Int, LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, 
promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]]

Set up the consumer warp group.

Returns:

Tuple[Int, LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, 
elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]]: Tuple of (local_warp_group_idx, c_reg_tile, final_c_reg_tile).

consumer_arrive_empty_barriers

static consumer_arrive_empty_barriers(warp_group_thread_idx: Int, mut pipeline: ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages])

Signal initial empty barrier arrival for all pipeline stages.

Must be called by consumer warp groups before the main loop so the producer knows it can start filling stages.

get_block_swizzle

static get_block_swizzle(lut_ptr: OptionalReg[UnsafePointer[UInt32, MutAnyOrigin]] = None) -> IndexList[2, element_type=DType.uint32]

Calculate block swizzle for better L2 cache locality.

Args:

Returns:

IndexList[2, element_type=DType.uint32]: Swizzled block indices.

consumer_output

static consumer_output[custom_elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = elementwise_lambda_fn](c_tma_op: TMATensorTile[c_type], c: LayoutTensor[c_type, MutAnyOrigin, address_space=c.address_space, element_layout=c.element_layout, layout_int_type=c.layout_int_type, linear_idx_type=c.linear_idx_type, masked=c.masked, alignment=c.alignment], c_tile: TileTensor[c_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], output_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, 
partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], warp_group_thread_idx: Int, local_warp_group_idx: Int, local_thread_idx: Int, block_y: Int, block_x: Int)

Handle consumer output by writing GEMM results to global memory.

build_tma_loaders

static build_tma_loaders[a_tma_rank: Int, b_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], //](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], rank_m: Int, rank_n: Int) -> Tuple[TileLoaderTMA[origin_of(a_tma_op), a_type, a_tma_rank, a_tile_shape, a_desc_shape, BK=HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, cluster_size=cluster_shape[0], use_partitioned_multicast=partitioned_multicast], TileLoaderTMA[origin_of(b_tma_op), b_type, b_tma_rank, b_tile_shape, b_desc_shape, BK=HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, cluster_size=cluster_shape[1], use_partitioned_multicast=partitioned_multicast]]

Returns:

Tuple[TileLoaderTMA[origin_of(a_tma_op), a_type, a_tma_rank, a_tile_shape, a_desc_shape, BK=HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, cluster_size=cluster_shape[0], use_partitioned_multicast=partitioned_multicast], TileLoaderTMA[origin_of(b_tma_op), b_type, b_tma_rank, b_tile_shape, b_desc_shape, BK=HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, cluster_size=cluster_shape[1], use_partitioned_multicast=partitioned_multicast]]

build_cpasync_loaders

static build_cpasync_loaders[k_align: Int, vector_size: Int = (k_align // size_of[a_type]()), num_threads_per_row: Int = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK // vector_size), thread_layout: Layout = Layout.row_major((WARPGROUP_SIZE // num_threads_per_row), num_threads_per_row)](a: LayoutTensor[a_type, a_layout, ImmutAnyOrigin], b: LayoutTensor[b_type, b_layout, ImmutAnyOrigin]) -> Tuple[TileLoaderCPAsync[a_type, a_layout, thread_layout, a_swizzle, vector_size], TileLoaderCPAsync[b_type, b_layout, thread_layout, b_swizzle, vector_size]]

Returns:

Tuple[TileLoaderCPAsync[a_type, a_layout, thread_layout, a_swizzle, vector_size], TileLoaderCPAsync[b_type, b_layout, thread_layout, b_swizzle, vector_size]]

producer_main_loop_pipeline

static producer_main_loop_pipeline[a_loader_type: TileLoader, b_loader_type: TileLoader, barrier_handler_type: BarrierHandler, //, num_k_iters: Int](m_coord: Int, n_coord: Int, k_coord: Int, a_loader: a_loader_type, b_loader: b_loader_type, barrier_handler: barrier_handler_type, mut pipeline: ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages], a_tiles: SMemTileArrayWithLayout[a_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]()), Coord(Idx[(128 // size_of[a_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[a_type]())](), Idx[(8 * (128 // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, 
c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]()))), num_pipeline_stages], b_tiles: SMemTileArrayWithLayout[b_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]()), Coord(Idx[(128 // size_of[b_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[b_type]())](), Idx[(8 * (128 // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, 
pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]()))), num_pipeline_stages])

run

static run[a_tma_rank: Int, b_tma_rank: Int, c_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], c_tile_shape: IndexList[c_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], c_desc_shape: IndexList[c_tma_rank]](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], c_tma_op: TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape], a: LayoutTensor[a_type, a_layout, ImmutAnyOrigin], b: LayoutTensor[b_type, b_layout, ImmutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], lut_ptr: UnsafePointer[UInt32, MutAnyOrigin])

Main kernel entry point for matrix multiplication.

This kernel implements a producer-consumer pattern where:

  • One warp group (producer) loads tiles from global memory using TMA
  • Multiple warp groups (consumers) perform matrix multiplication using tensor cores

The kernel uses software pipelining to overlap memory transfers with computation, achieving high throughput on Hopper GPUs.

Args:

run_splitk

static run_splitk[a_tma_rank: Int, b_tma_rank: Int, c_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], c_tile_shape: IndexList[c_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], c_desc_shape: IndexList[c_tma_rank], splits: Int, raster_order: RasterOrder](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], c_tma_op: TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], workspace_ptr: UnsafePointer[Scalar[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type], MutAnyOrigin], locks_ptr: UnsafePointer[UInt8, MutAnyOrigin], problem_shape: IndexList[3])

Split-K variant of the kernel, which divides the K dimension into `splits` partitions to improve load balancing on small problems.

run_grouped

static run_grouped[a_tma_rank: Int, b_tma_rank: Int, c_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], c_tile_shape: IndexList[c_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], c_desc_shape: IndexList[c_tma_rank], AOffsetsLayout: TensorLayout, ExpertIdsLayout: TensorLayout](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], c_tma_op: TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape], a_offsets: TileTensor[DType.uint32, AOffsetsLayout, ImmutAnyOrigin], expert_ids: TileTensor[DType.int32, ExpertIdsLayout, ImmutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin])

Grouped matmul variant for MoE (Mixture of Experts) models.

This variant handles multiple experts, each of which processes a subset of the tokens. The `a_offsets` array marks the token boundaries for each expert.

consumer_main_loop_pipeline

static consumer_main_loop_pipeline[num_k_iters: Int](wgmma_op: TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: Int, final_c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, 
b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], mut pipeline: ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, 
c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages], a_tiles: SMemTileArrayWithLayout[a_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]()), Coord(Idx[(128 // size_of[a_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[a_type]())](), Idx[(8 * (128 // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, 
num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]()))), num_pipeline_stages], b_tiles: SMemTileArrayWithLayout[b_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]()), Coord(Idx[(128 // size_of[b_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[b_type]())](), Idx[(8 * (128 // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, 
c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]()))), num_pipeline_stages], warp_group_thread_idx: Int)

Pipeline-based consumer loop using ProducerConsumerPipeline.

This is an alternative implementation of `consumer_main_loop` that synchronizes with the SM100 `ProducerConsumerPipeline` instead of `RingBuffer`.

Args:

promote_to_cuda_cores

static promote_to_cuda_cores(c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], final_c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, 
promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL])

Promote FP8 accumulation to higher precision using CUDA cores.

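When FP8 data types are used, tensor cores accumulate in limited precision. To maintain accuracy over many accumulations, the kernel periodically adds the intermediate results to a higher-precision accumulator using CUDA cores. How often this promotion happens is controlled by the `promotion_frequency` template parameter.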

This technique is commonly used in production libraries like cuBLAS to achieve both high performance (from FP8 tensor cores) and good accuracy.

Args:

wgmma

static wgmma(wgmma_op: TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: Int, a_tile: TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], b_tile: TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, 
b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL])