Mojo struct
HopperMatmulSM90Kernel
struct HopperMatmulSM90Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_smem_layout: Layout, block_tile_shape: IndexList[3], wgmma_shape: IndexList[3], cluster_shape: StaticTuple[Int32, 3], num_pipeline_stages: Int, num_threads: Int = 128, transpose_b: Bool = True, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, partitioned_multicast: Bool = False, use_tma_store: Bool = False, promotion_frequency: Int = 1, pdl_level: PDLLevel = PDLLevel(), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, hilbert_swizzle: Bool = False, k_group_size: Int = 1, swapAB: Bool = False]
Hopper SM90 Matrix Multiplication kernel optimized for NVIDIA H100 GPUs.
This kernel implements a highly optimized matrix multiplication (GEMM) using:
- Tensor Memory Accelerator (TMA) for efficient global-to-shared memory transfers
- Warp Group Matrix Multiply Accumulate (WGMMA) instructions for tensor cores
- Multi-stage software pipelining for overlapping compute and memory operations
- Producer-consumer model with separate warp groups for loading and computing
Template Parameters:
- a_type, b_type, c_type: Data types for input and output matrices.
- a_layout, b_layout, c_layout: Memory layouts for matrices.
- c_smem_layout: Shared memory layout for output tile.
- block_tile_shape: Tile dimensions [M, N, K] processed by each thread block.
- wgmma_shape: Dimensions for each WGMMA instruction [M, N, K].
- cluster_shape: Thread block cluster dimensions for distributed shared memory.
- num_pipeline_stages: Number of stages in the software pipeline (typically 3-7).
- num_threads: Number of threads per block (must be a multiple of 128).
- transpose_b: Whether B matrix is transposed (required to be True).
- a_swizzle, b_swizzle: Memory swizzling for bank-conflict-free access.
- c_swizzle: Swizzling for output writes.
- partitioned_multicast: Enable partitioned multicast for large tiles.
- use_tma_store: Use TMA for storing output (vs regular stores).
- promotion_frequency: How often to promote FP8 accumulation to higher precision.
- pdl_level: Programmatic Dependency Launch (PDL) level.
- elementwise_lambda_fn: Optional epilogue function.
- elementwise_compute_lambda_fn: Optional compute function.
- hilbert_swizzle: Use Hilbert curve for thread block scheduling.
- k_group_size: Groups pipeline stages per K step (adjusted_num_pipeline_stages = num_pipeline_stages // k_group_size).
- swapAB: Presumably swaps the roles of the A and B operands — confirm against the implementation.
Implemented traits:
AnyType,
ImplicitlyDestructible
Compile-time members:
a_smem_layout
comptime a_smem_layout = tile_layout_k_major[a_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, a_swizzle]()
accum_type
comptime accum_type = get_accum_type[a_type]()
AccumRegTile
comptime AccumRegTile = LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]
adjusted_num_pipeline_stages
comptime adjusted_num_pipeline_stages = (num_pipeline_stages // k_group_size)
b_smem_layout
comptime b_smem_layout = tile_layout_k_major[b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, b_swizzle]()
BK
comptime BK = block_tile_shape[2]
BM
comptime BM = block_tile_shape[0]
BN
comptime BN = block_tile_shape[1]
c_frag_size
comptime c_frag_size = ((wgmma_shape[0] * wgmma_shape[1]) // 128)
cluster_size
comptime cluster_size = Int[Int32](((cluster_shape[0] * cluster_shape[1]) * cluster_shape[2]))
num_consumer
comptime num_consumer = ((num_threads // 128) - 1)
num_consumer_threads
comptime num_consumer_threads = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_consumer * 128)
num_m_mmas
comptime num_m_mmas = ((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // wgmma_shape[0]) // HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_consumer)
num_n_mmas
comptime num_n_mmas = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // wgmma_shape[1])
SMem
comptime SMem = HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, c_smem_layout.shape[0].value(), c_smem_layout.shape[1].value(), num_pipeline_stages, k_group_size]
TMABarrier
comptime TMABarrier = TMABarrierHandler[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].SMem.ATileArray.storage_size + HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].SMem.BTileArray.storage_size) // HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages)]
WgmmaOp
comptime WgmmaOp = TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b]
Methods:
validate_constraints
static validate_constraints()
Validate common constraints for all kernel variants.
pipeline_init
static pipeline_init()
Initialize pipeline synchronization barriers.
This function ensures that all pipeline initialization (barriers, shared memory) is visible to all thread blocks in the cluster before proceeding. This is critical for correct producer-consumer synchronization.
For multi-cluster configurations, uses fence and cluster sync. For single block, uses a simple barrier.
finalize_kernel
static finalize_kernel()
Common finalization for all kernel variants.
multicast_mask
common_kernel_init
static common_kernel_init() -> Tuple[Int, Int, Int, Int, Int, Bool]
Common initialization for all kernel variants.
Returns:
Tuple[Int, Int, Int, Int, Int, Bool]: Tuple of (warp_group_idx, warp_group_thread_idx, rank_m, rank_n, warp_id, lane_predicate).
setup_producer
static setup_producer() -> Int
Setup producer warp group by deallocating registers.
Returns:
Int: Number of registers deallocated.
setup_consumer
static setup_consumer(warp_group_idx: Int) -> Tuple[Int, LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, 
promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]]
Setup consumer warp group.
Returns:
consumer_arrive_empty_barriers
static consumer_arrive_empty_barriers(warp_group_thread_idx: Int, mut pipeline: ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages])
Signal initial empty barrier arrival for all pipeline stages.
Must be called by consumer warp groups before the main loop so the producer knows it can start filling stages.
get_block_swizzle
static get_block_swizzle(lut_ptr: OptionalReg[UnsafePointer[UInt32, MutAnyOrigin]] = None) -> IndexList[2, element_type=DType.uint32]
Calculate block swizzle for better L2 cache locality.
Args:
- lut_ptr (OptionalReg[UnsafePointer[UInt32, MutAnyOrigin]]): Lookup table for Hilbert curve block scheduling (optional).
Returns:
IndexList[2, element_type=DType.uint32]: Swizzled block indices.
consumer_output
static consumer_output[custom_elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = elementwise_lambda_fn](c_tma_op: TMATensorTile[c_type], c: LayoutTensor[c_type, MutAnyOrigin, address_space=c.address_space, element_layout=c.element_layout, layout_int_type=c.layout_int_type, linear_idx_type=c.linear_idx_type, masked=c.masked, alignment=c.alignment], c_tile: TileTensor[c_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], output_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, 
partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], warp_group_thread_idx: Int, local_warp_group_idx: Int, local_thread_idx: Int, block_y: Int, block_x: Int)
Handle consumer output by writing GEMM results to global memory.
build_tma_loaders
static build_tma_loaders[a_tma_rank: Int, b_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], //](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], rank_m: Int, rank_n: Int) -> Tuple[TileLoaderTMA[origin_of(a_tma_op), a_type, a_tma_rank, a_tile_shape, a_desc_shape, BK=HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, cluster_size=cluster_shape[0], use_partitioned_multicast=partitioned_multicast], TileLoaderTMA[origin_of(b_tma_op), b_type, b_tma_rank, b_tile_shape, b_desc_shape, BK=HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK, cluster_size=cluster_shape[1], use_partitioned_multicast=partitioned_multicast]]
Returns:
build_cpasync_loaders
static build_cpasync_loaders[k_align: Int, vector_size: Int = (k_align // size_of[a_type]()), num_threads_per_row: Int = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK // vector_size), thread_layout: Layout = Layout.row_major((WARPGROUP_SIZE // num_threads_per_row), num_threads_per_row)](a: LayoutTensor[a_type, a_layout, ImmutAnyOrigin], b: LayoutTensor[b_type, b_layout, ImmutAnyOrigin]) -> Tuple[TileLoaderCPAsync[a_type, a_layout, thread_layout, a_swizzle, vector_size], TileLoaderCPAsync[b_type, b_layout, thread_layout, b_swizzle, vector_size]]
Returns:
producer_main_loop_pipeline
static producer_main_loop_pipeline[a_loader_type: TileLoader, b_loader_type: TileLoader, barrier_handler_type: BarrierHandler, //, num_k_iters: Int](m_coord: Int, n_coord: Int, k_coord: Int, a_loader: a_loader_type, b_loader: b_loader_type, barrier_handler: barrier_handler_type, mut pipeline: ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages], a_tiles: SMemTileArrayWithLayout[a_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]()), Coord(Idx[(128 // size_of[a_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[a_type]())](), Idx[(8 * (128 // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, 
c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]()))), num_pipeline_stages], b_tiles: SMemTileArrayWithLayout[b_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]()), Coord(Idx[(128 // size_of[b_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[b_type]())](), Idx[(8 * (128 // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, 
pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]()))), num_pipeline_stages])
run
static run[a_tma_rank: Int, b_tma_rank: Int, c_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], c_tile_shape: IndexList[c_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], c_desc_shape: IndexList[c_tma_rank]](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], c_tma_op: TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape], a: LayoutTensor[a_type, a_layout, ImmutAnyOrigin], b: LayoutTensor[b_type, b_layout, ImmutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], lut_ptr: UnsafePointer[UInt32, MutAnyOrigin])
Main kernel entry point for matrix multiplication.
This kernel implements a producer-consumer pattern where:
- One warp group (producer) loads tiles from global memory using TMA
- Multiple warp groups (consumers) perform matrix multiplication using tensor cores
The kernel uses software pipelining to overlap memory transfers with computation, achieving high throughput on Hopper GPUs.
Args:
- a_tma_op (TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape]): TMA descriptor for matrix A.
- b_tma_op (TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape]): TMA descriptor for matrix B.
- c_tma_op (TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape]): TMA descriptor for matrix C.
- a (LayoutTensor[a_type, a_layout, ImmutAnyOrigin]): Input matrix A.
- b (LayoutTensor[b_type, b_layout, ImmutAnyOrigin]): Input matrix B.
- c (LayoutTensor[c_type, c_layout, MutAnyOrigin]): Output matrix C.
- lut_ptr (UnsafePointer[UInt32, MutAnyOrigin]): Lookup table for Hilbert curve block scheduling (optional).
run_splitk
static run_splitk[a_tma_rank: Int, b_tma_rank: Int, c_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], c_tile_shape: IndexList[c_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], c_desc_shape: IndexList[c_tma_rank], splits: Int, raster_order: RasterOrder](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], c_tma_op: TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], workspace_ptr: UnsafePointer[Scalar[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type], MutAnyOrigin], locks_ptr: UnsafePointer[UInt8, MutAnyOrigin], problem_shape: IndexList[3])
Split-K variant of the kernel for better load balancing on small problems.
run_grouped
static run_grouped[a_tma_rank: Int, b_tma_rank: Int, c_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], c_tile_shape: IndexList[c_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], c_desc_shape: IndexList[c_tma_rank], AOffsetsLayout: TensorLayout, ExpertIdsLayout: TensorLayout](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], c_tma_op: TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape], a_offsets: TileTensor[DType.uint32, AOffsetsLayout, ImmutAnyOrigin], expert_ids: TileTensor[DType.int32, ExpertIdsLayout, ImmutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin])
Grouped matmul variant for MoE (Mixture of Experts) models.
This variant handles multiple experts where each expert processes a subset of tokens. The a_offsets array indicates token boundaries for each expert.
consumer_main_loop_pipeline
static consumer_main_loop_pipeline[num_k_iters: Int](wgmma_op: TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: Int, final_c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, 
b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], mut pipeline: ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, 
c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages], a_tiles: SMemTileArrayWithLayout[a_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]()), Coord(Idx[(128 // size_of[a_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[a_type]())](), Idx[(8 * (128 // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, 
num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]()))), num_pipeline_stages], b_tiles: SMemTileArrayWithLayout[b_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]()), Coord(Idx[(128 // size_of[b_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[b_type]())](), Idx[(8 * (128 // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, 
c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]()))), num_pipeline_stages], warp_group_thread_idx: Int)
Pipeline-based consumer loop using ProducerConsumerPipeline.
This is an alternative implementation of consumer_main_loop that uses the SM100 ProducerConsumerPipeline for synchronization instead of RingBuffer.
Args:
- wgmma_op (
TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b]): Tensor core operator for matrix multiplication. - βlocal_warp_group_idx (
Int): Index of this consumer warp group (0-based). - final_c_reg_tile (
LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]): Final accumulation register tile (for FP8 promotion). - βc_reg_tile (
LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]): Working accumulation register tile. - βpipeline (
ProducerConsumerPipeline[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages]): ProducerConsumerPipeline for synchronized tile access. - βa_tiles (
SMemTileArrayWithLayout[a_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM // 8)]()), Coord(Idx[(128 // size_of[a_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[a_type]())](), Idx[(8 * (128 // size_of[a_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[a_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BM * (128 // size_of[a_type]()))]()))), num_pipeline_stages]): Tile array for A matrix in shared memory. - βb_tiles (
SMemTileArrayWithLayout[b_type, Layout(Coord(Coord(Idx[8](), Idx[(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN // 8)]()), Coord(Idx[(128 // size_of[b_type]())](), Idx[((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128)]())), Coord(Coord(Idx[(128 // size_of[b_type]())](), Idx[(8 * (128 // size_of[b_type]()))]()), Coord(Idx[1](), Idx[0 if (((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BK * size_of[b_type]()) // 128) == 1) else (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].BN * (128 // size_of[b_type]()))]()))), num_pipeline_stages]): Tile array for B matrix in shared memory. - βwarp_group_thread_idx (
Int): Thread index within the warp group.
promote_to_cuda_cores
static promote_to_cuda_cores(c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], final_c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, 
promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL])
Promote FP8 accumulation to higher precision using CUDA cores.
When using FP8 data types, tensor cores accumulate in limited precision. To maintain accuracy over many accumulations, we periodically add the intermediate results to a higher-precision accumulator using CUDA cores.
This technique is commonly used in production libraries like cuBLAS to achieve both high performance (from FP8 tensor cores) and good accuracy.
Args:
- c_reg_tile (
LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]): Current accumulation from tensor cores. - βfinal_c_reg_tile (
LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]): Higher-precision accumulator (updated in place).
wgmma
static wgmma(wgmma_op: TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: Int, a_tile: TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], b_tile: TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, 
b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!