Mojo struct

HopperMatmulSM90Kernel

struct HopperMatmulSM90Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout, c_smem_layout: Layout, block_tile_shape: IndexList[Int(3)], wgmma_shape: IndexList[Int(3)], cluster_shape: StaticTuple[Int32, Int(3)], num_pipeline_stages: Int, num_threads: Int = Int(128), transpose_b: Bool = True, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, partitioned_multicast: Bool = False, use_tma_store: Bool = False, promotion_frequency: Int = Int(1), pdl_level: PDLLevel = PDLLevel(), elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, hilbert_swizzle: Bool = False, k_group_size: Int = Int(1), swapAB: Bool = False]

Hopper SM90 Matrix Multiplication kernel optimized for NVIDIA H100 GPUs.

This kernel implements a highly optimized matrix multiplication (GEMM) using:

  • Tensor Memory Accelerator (TMA) for efficient global-to-shared memory transfers
  • Warp Group Matrix Multiply Accumulate (WGMMA) instructions for tensor cores
  • Multi-stage software pipelining for overlapping compute and memory operations
  • Producer-consumer model with separate warp groups for loading and computing

Template Parameters:

  • a_type, b_type, c_type: Data types for input and output matrices
  • a_layout, b_layout, c_layout: Memory layouts for matrices
  • c_smem_layout: Shared memory layout for output tile
  • block_tile_shape: Tile dimensions [M, N, K] processed by each thread block
  • wgmma_shape: Dimensions for each WGMMA instruction [M, N, K]
  • cluster_shape: Thread block cluster dimensions for distributed shared memory
  • num_pipeline_stages: Number of stages in the software pipeline (typically 3-7)
  • num_threads: Number of threads per block (must be multiple of 128)
  • transpose_b: Whether B matrix is transposed (required to be True)
  • a_swizzle, b_swizzle: Memory swizzling for bank-conflict-free access
  • c_swizzle: Swizzling for output writes
  • partitioned_multicast: Enable partitioned multicast for large tiles
  • use_tma_store: Use TMA for storing output (vs regular stores)
  • promotion_frequency: How often to promote FP8 accumulation to higher precision
  • pdl_level: Programmatic Dependency Launch (PDL) level
  • elementwise_lambda_fn: Optional epilogue function
  • elementwise_compute_lambda_fn: Optional compute function
  • hilbert_swizzle: Use Hilbert curve for thread block scheduling

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

a_smem_layout​

comptime a_smem_layout = tile_layout_k_major[a_type, block_tile_shape[Int(0)], block_tile_shape[Int(2)], a_swizzle]()

accum_type​

comptime accum_type = get_accum_type[a_type]()

AccumRegTile​

comptime AccumRegTile = LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL]

adjusted_num_pipeline_stages​

comptime adjusted_num_pipeline_stages = (num_pipeline_stages // k_group_size)

b_smem_layout​

comptime b_smem_layout = tile_layout_k_major[b_type, block_tile_shape[Int(1)], block_tile_shape[Int(2)], b_swizzle]()

BK​

comptime BK = block_tile_shape[Int(2)]

BM​

comptime BM = block_tile_shape[Int(0)]

BN​

comptime BN = block_tile_shape[Int(1)]

c_frag_size​

comptime c_frag_size = (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))

cluster_size​

comptime cluster_size = SIMD(((cluster_shape[Int(0)] * cluster_shape[Int(1)]) * cluster_shape[Int(2)]))

num_consumer​

comptime num_consumer = ((num_threads // Int(128)) - Int(1))

num_consumer_threads​

comptime num_consumer_threads = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_consumer * Int(128))

num_m_mmas​

comptime num_m_mmas = ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].num_consumer)

num_n_mmas​

comptime num_n_mmas = (block_tile_shape[Int(1)] // wgmma_shape[Int(1)])
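
The tile-count members above (num_consumer, num_m_mmas, num_n_mmas, c_frag_size) are plain integer arithmetic over the struct parameters. The following is a minimal Python sketch of that arithmetic, not Mojo code and not part of the library. The `TileConfig` class and the example shapes are hypothetical and chosen only for illustration.

```python
from dataclasses import dataclass

WARP_GROUP_SIZE = 128  # threads per warp group, from the `// Int(128)` terms above


@dataclass(frozen=True)
class TileConfig:
    """Hypothetical stand-in for the struct's compile-time parameters."""

    block_tile_shape: tuple  # (BM, BN, BK)
    wgmma_shape: tuple       # (M, N, K) of one WGMMA instruction
    num_threads: int         # threads per block, a multiple of 128

    @property
    def num_consumer(self):
        # All warp groups except one; the remaining one is the producer.
        return self.num_threads // WARP_GROUP_SIZE - 1

    @property
    def num_m_mmas(self):
        # WGMMA tiles along M handled by each consumer warp group.
        return (self.block_tile_shape[0] // self.wgmma_shape[0]) // self.num_consumer

    @property
    def num_n_mmas(self):
        # WGMMA tiles along N.
        return self.block_tile_shape[1] // self.wgmma_shape[1]

    @property
    def c_frag_size(self):
        # Accumulator elements each of the 128 threads holds per WGMMA tile.
        return (self.wgmma_shape[0] * self.wgmma_shape[1]) // WARP_GROUP_SIZE

    @property
    def accum_reg_tile_shape(self):
        # Row-major shape used by AccumRegTile.
        return (self.num_n_mmas * self.num_m_mmas, self.c_frag_size)


# Example values are assumptions, not a recommended configuration.
cfg = TileConfig(block_tile_shape=(128, 256, 64), wgmma_shape=(64, 256, 16), num_threads=384)
print(cfg.num_consumer, cfg.num_m_mmas, cfg.num_n_mmas, cfg.c_frag_size)
```

Under these formulas, num_threads = 128 gives num_consumer = 0, so num_m_mmas would divide by zero. The sketch therefore only makes sense for at least 256 threads (one producer warp group plus one or more consumer warp groups).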

SMem​

comptime SMem = HopperMatmulSM90Kernel_SMem[a_type, b_type, c_type, block_tile_shape[Int(0)], block_tile_shape[Int(1)], block_tile_shape[Int(2)], c_smem_layout.shape[0].value(), c_smem_layout.shape[1].value(), num_pipeline_stages, k_group_size]

TMABarrier​

comptime TMABarrier = TMABarrierHandler[(Int((add (mul Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt()))).product(), size_of[a_type](), num_pipeline_stages), (mul Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt()))).product(), size_of[b_type](), num_pipeline_stages))) // HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].adjusted_num_pipeline_stages)]
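
The integer parameter of TMABarrier is easier to read once the arithmetic is written out. The Python sketch below mirrors the expression above and assumes that the two opaque `Layout(...).product()` terms are the element counts of one A tile (BM × BK) and one B tile (BN × BK), which matches a_smem_layout and b_smem_layout. That reading, the function name `tma_barrier_bytes`, and the example sizes are assumptions made for illustration.

```python
def tma_barrier_bytes(bm, bn, bk, a_elem_bytes, b_elem_bytes,
                      num_pipeline_stages, k_group_size=1):
    """Mirror of the TMABarrier parameter expression (hypothetical helper)."""
    # adjusted_num_pipeline_stages, as defined earlier on this page
    adjusted_stages = num_pipeline_stages // k_group_size
    # Bytes held by all pipeline stages of A and B tiles (assumed tile sizes)
    a_total = bm * bk * a_elem_bytes * num_pipeline_stages
    b_total = bn * bk * b_elem_bytes * num_pipeline_stages
    # Divide across the adjusted stage count
    return (a_total + b_total) // adjusted_stages


# 2-byte elements (for example bfloat16), 4 stages
print(tma_barrier_bytes(128, 256, 64, 2, 2, num_pipeline_stages=4))
print(tma_barrier_bytes(128, 256, 64, 2, 2, num_pipeline_stages=4, k_group_size=2))
```

With k_group_size = 1 the result is the size of one A tile plus one B tile. With k_group_size = 2 it doubles, which suggests that each barrier then covers a group of k_group_size tile pairs.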

WgmmaOp​

comptime WgmmaOp = TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle, k_group_size, swapAB].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b]

Methods​

validate_constraints​

static def validate_constraints()

Validate common constraints for all kernel variants.

pipeline_init​

static def pipeline_init()

Initialize pipeline synchronization barriers.

This function ensures that all pipeline initialization (barriers, shared memory) is visible to all thread blocks in the cluster before proceeding. This is critical for correct producer-consumer synchronization.

For multi-cluster configurations, it uses a fence and a cluster sync. For a single block, it uses a simple barrier.

finalize_kernel​

static def finalize_kernel()

Common finalization for all kernel variants.

multicast_mask​

static def multicast_mask(rank_m: Int, rank_n: Int) -> Tuple[Int32, Int32]

Returns:

Tuple[Int32, Int32]

common_kernel_init​

static def common_kernel_init() -> Tuple[Int, Int, Int, Int, Int, Bool]

Common initialization for all kernel variants.

Returns:

Tuple[Int, Int, Int, Int, Int, Bool]: Tuple of (warp_group_idx, warp_group_thread_idx, rank_m, rank_n, warp_id, lane_predicate).

setup_producer​

static def setup_producer() -> Int

Setup producer warp group by deallocating registers.

Returns:

Int: Number of registers deallocated.

setup_consumer​

static def setup_consumer(warp_group_idx: Int) -> Tuple[Int, LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL], LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL]]

Setup consumer warp group.

Returns:

Tuple[Int, LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL], LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL]]: Tuple of (local_warp_group_idx, c_reg_tile, final_c_reg_tile).

consumer_arrive_empty_barriers​

static def consumer_arrive_empty_barriers(warp_group_thread_idx: Int, mut pipeline: ProducerConsumerPipeline[Self.adjusted_num_pipeline_stages])

Signal initial empty barrier arrival for all pipeline stages.

Must be called by consumer warp groups before the main loop so the producer knows it can start filling stages.
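
The reason for the initial arrival can be illustrated with a small host-side model. The sketch below is plain Python, not GPU code: semaphores stand in for the per-stage "empty" and "full" barriers, and integers stand in for tiles. The function `run_pipeline` and its structure are assumptions for illustration; the real ProducerConsumerPipeline is not modeled here.

```python
import threading


def run_pipeline(num_stages, num_k_iters):
    """Toy producer-consumer ring with per-stage empty/full signals."""
    empty = [threading.Semaphore(0) for _ in range(num_stages)]
    full = [threading.Semaphore(0) for _ in range(num_stages)]
    stages = [None] * num_stages
    consumed = []

    def producer():
        for k in range(num_k_iters):
            s = k % num_stages
            empty[s].acquire()   # wait until the consumer has freed this stage
            stages[s] = k        # stands in for loading the k-th A/B tiles
            full[s].release()    # tell the consumer the stage is ready

    def consumer():
        # Initial arrival: mark every stage as empty so the producer can start.
        # Without this loop the producer would block forever on empty[0].
        for s in range(num_stages):
            empty[s].release()
        for k in range(num_k_iters):
            s = k % num_stages
            full[s].acquire()    # wait for the producer to fill the stage
            consumed.append(stages[s])
            empty[s].release()   # hand the stage back to the producer

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return consumed


print(run_pipeline(num_stages=3, num_k_iters=10))
```

The producer never overwrites a stage that has not been consumed, so the tiles come out in issue order.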

get_block_swizzle​

static def get_block_swizzle(lut_ptr: OptionalReg[UnsafePointer[UInt32, MutAnyOrigin]] = None) -> IndexList[Int(2), element_type=DType.uint32]

Calculate block swizzle for better L2 cache locality.

Returns:

IndexList[Int(2), element_type=DType.uint32]: Swizzled block indices.
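
This page does not describe the mapping that get_block_swizzle applies. As a generic illustration of how a block swizzle can improve L2 locality, the Python sketch below walks the output grid in narrow column strips, so that consecutively scheduled blocks reuse the same few B tiles. The function `swizzle_block` and the `group_n` parameter are hypothetical and are not the library's scheme.

```python
def swizzle_block(block_id, grid_m, grid_n, group_n):
    """Map a linear block id to (block_m, block_n), strip by strip."""
    blocks_per_group = grid_m * group_n
    group = block_id // blocks_per_group
    first_n = group * group_n
    # The last strip may be narrower than group_n.
    width = min(group_n, grid_n - first_n)
    within = block_id - group * blocks_per_group
    return within // width, first_n + within % width


# A 4 x 5 grid of output tiles walked in strips that are 2 tiles wide
order = [swizzle_block(i, grid_m=4, grid_n=5, group_n=2) for i in range(20)]
print(order[:4])
```

Every tile of the grid is visited exactly once, and the first eight block ids stay inside columns 0 and 1.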

consumer_output​

static def consumer_output[custom_elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = elementwise_lambda_fn](c_tma_op: TMATensorTile[c_type], c: TileTensor[c_type, Storage=c.Storage, linear_idx_type=c.linear_idx_type], c_tile: TileTensor[c_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], output_reg_tile: LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL], warp_group_thread_idx: Int, local_warp_group_idx: Int, local_thread_idx: Int, block_y: Int, block_x: Int)

Handle consumer output by writing GEMM results to global memory.

build_tma_loaders​

static def build_tma_loaders[a_tma_rank: Int, b_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], //](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], rank_m: Int, rank_n: Int) -> Tuple[TileLoaderTMA[origin_of(a_tma_op), a_type, a_tma_rank, a_tile_shape, a_desc_shape, BK=block_tile_shape[Int(2)], cluster_size=cluster_shape[Int(0)], use_partitioned_multicast=partitioned_multicast], TileLoaderTMA[origin_of(b_tma_op), b_type, b_tma_rank, b_tile_shape, b_desc_shape, BK=block_tile_shape[Int(2)], cluster_size=cluster_shape[Int(1)], use_partitioned_multicast=partitioned_multicast]]

Returns:

Tuple[TileLoaderTMA[origin_of(a_tma_op), a_type, a_tma_rank, a_tile_shape, a_desc_shape, BK=block_tile_shape[Int(2)], cluster_size=cluster_shape[Int(0)], use_partitioned_multicast=partitioned_multicast], TileLoaderTMA[origin_of(b_tma_op), b_type, b_tma_rank, b_tile_shape, b_desc_shape, BK=block_tile_shape[Int(2)], cluster_size=cluster_shape[Int(1)], use_partitioned_multicast=partitioned_multicast]]

build_cpasync_loaders​

static def build_cpasync_loaders[k_align: Int, vector_size: Int = (k_align // size_of[a_type]()), num_threads_per_row: Int = (block_tile_shape[Int(2)] // vector_size), thread_layout: Layout[thread_layout.shape_types, thread_layout.stride_types] = row_major[(_resolve_warpgroup_size() // num_threads_per_row), num_threads_per_row]()](a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin]) -> Tuple[TileLoaderCPAsync[a_type, a_layout, thread_layout, a_swizzle, vector_size], TileLoaderCPAsync[b_type, b_layout, thread_layout, b_swizzle, vector_size]]

Returns:

Tuple[TileLoaderCPAsync[a_type, a_layout, thread_layout, a_swizzle, vector_size], TileLoaderCPAsync[b_type, b_layout, thread_layout, b_swizzle, vector_size]]

producer_main_loop_pipeline​

static def producer_main_loop_pipeline[a_loader_type: TileLoader, b_loader_type: TileLoader, barrier_handler_type: BarrierHandler, //, num_k_iters: Int](m_coord: Int, n_coord: Int, k_coord: Int, a_loader: a_loader_type, b_loader: b_loader_type, barrier_handler: barrier_handler_type, mut pipeline: ProducerConsumerPipeline[Self.adjusted_num_pipeline_stages], a_tiles: SMemTileArrayWithLayout[a_type, Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt()))), num_pipeline_stages], b_tiles: SMemTileArrayWithLayout[b_type, Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt()))), num_pipeline_stages])

run​

static def run[a_tma_rank: Int, b_tma_rank: Int, c_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], c_tile_shape: IndexList[c_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], c_desc_shape: IndexList[c_tma_rank], a_tensor_layout: TensorLayout, b_tensor_layout: TensorLayout, c_tensor_layout: TensorLayout](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], c_tma_op: TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape], a: TileTensor[a_type, a_tensor_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_tensor_layout, ImmutAnyOrigin], c: TileTensor[c_type, c_tensor_layout, MutAnyOrigin], lut_ptr: UnsafePointer[UInt32, MutAnyOrigin])

Main kernel entry point for matrix multiplication.

This kernel implements a producer-consumer pattern where:

  • One warp group (producer) loads tiles from global memory using TMA
  • Multiple warp groups (consumers) perform matrix multiplication using tensor cores

The kernel uses software pipelining to overlap memory transfers with computation, achieving high throughput on Hopper GPUs.

run_splitk​

static def run_splitk[a_tma_rank: Int, b_tma_rank: Int, c_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], c_tile_shape: IndexList[c_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], c_desc_shape: IndexList[c_tma_rank], splits: Int, raster_order: RasterOrder, c_tensor_layout: TensorLayout](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], c_tma_op: TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape], c: TileTensor[c_type, c_tensor_layout, MutAnyOrigin], workspace_ptr: UnsafePointer[Scalar[Self.accum_type], MutAnyOrigin], locks_ptr: UnsafePointer[UInt8, MutAnyOrigin], problem_shape: IndexList[Int(3)])

Split-K variant of the kernel for better load balancing on small problems.
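
As a reminder of what split-K means, the Python sketch below partitions the K dimension into `splits` chunks, computes one partial product per chunk, and then reduces the partials. It is a scalar reference model and not the kernel's implementation: the `workspace` list only loosely stands in for workspace_ptr, the locking implied by locks_ptr is not modeled, and the function name is an assumption. B is stored as N × K to match transpose_b = True.

```python
def matmul_splitk(a, b_t, splits):
    """Reference split-K GEMM: a is M x K, b_t is N x K (B transposed)."""
    m, k, n = len(a), len(a[0]), len(b_t)
    chunk = (k + splits - 1) // splits  # K elements per split, rounded up
    workspace = []
    for s in range(splits):
        k0, k1 = s * chunk, min((s + 1) * chunk, k)
        # Partial product over this split's slice of K
        partial = [
            [sum(a[i][kk] * b_t[j][kk] for kk in range(k0, k1)) for j in range(n)]
            for i in range(m)
        ]
        workspace.append(partial)
    # Reduce the partial results into the final output
    return [[sum(p[i][j] for p in workspace) for j in range(n)] for i in range(m)]


a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b_t = [[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 1, 1]]
print(matmul_splitk(a, b_t, splits=2))
```

Each split has independent work, which is why the technique helps when the M × N grid alone produces too few thread blocks to fill the GPU.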

run_grouped​

static def run_grouped[a_tma_rank: Int, b_tma_rank: Int, c_tma_rank: Int, a_tile_shape: IndexList[a_tma_rank], b_tile_shape: IndexList[b_tma_rank], c_tile_shape: IndexList[c_tma_rank], a_desc_shape: IndexList[a_tma_rank], b_desc_shape: IndexList[b_tma_rank], c_desc_shape: IndexList[c_tma_rank], AOffsetsLayout: TensorLayout, ExpertIdsLayout: TensorLayout, c_tensor_layout: TensorLayout](a_tma_op: TMATensorTile[a_type, a_tma_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tma_rank, b_tile_shape, b_desc_shape], c_tma_op: TMATensorTile[c_type, c_tma_rank, c_tile_shape, c_desc_shape], a_offsets: TileTensor[DType.uint32, AOffsetsLayout, ImmutAnyOrigin], expert_ids: TileTensor[DType.int32, ExpertIdsLayout, ImmutAnyOrigin], c: TileTensor[c_type, c_tensor_layout, MutAnyOrigin])

Grouped matmul variant for MoE (Mixture of Experts) models.

This variant handles multiple experts where each expert processes a subset of tokens. The a_offsets array indicates token boundaries for each expert.
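
The role of a_offsets and expert_ids can be sketched with a scalar reference model in Python. The sketch assumes that a_offsets holds one more entry than the number of groups, so that group g covers rows a_offsets[g] up to (but not including) a_offsets[g + 1], and that expert_ids[g] selects the weight matrix for that group. Both the layout and the function name `grouped_matmul` are assumptions for illustration; they are not confirmed by this page.

```python
def grouped_matmul(a, b_experts, a_offsets, expert_ids):
    """Reference grouped GEMM.

    a:          tokens x K activations
    b_experts:  per-expert weights, each N x K (B transposed)
    a_offsets:  token boundaries, len(expert_ids) + 1 entries (assumed)
    expert_ids: expert index used by each group of tokens
    """
    c = [None] * len(a)
    for g, expert in enumerate(expert_ids):
        weights = b_experts[expert]
        for row in range(a_offsets[g], a_offsets[g + 1]):
            # One output row: dot product of the token with every weight row
            c[row] = [sum(x * w for x, w in zip(a[row], w_row)) for w_row in weights]
    return c


a = [[1, 0], [0, 1], [1, 1]]
b_experts = [[[1, 2], [3, 4]], [[10, 20], [30, 40]]]
# Tokens 0-1 go to expert 1, token 2 goes to expert 0
print(grouped_matmul(a, b_experts, a_offsets=[0, 2, 3], expert_ids=[1, 0]))
```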

consumer_main_loop_pipeline​

static def consumer_main_loop_pipeline[num_k_iters: Int](wgmma_op: TensorCoreAsync[Self.accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: Int, final_c_reg_tile: LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL], c_reg_tile: LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL], mut pipeline: ProducerConsumerPipeline[Self.adjusted_num_pipeline_stages], a_tiles: SMemTileArrayWithLayout[a_type, Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt()))), num_pipeline_stages], b_tiles: SMemTileArrayWithLayout[b_type, Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt()))), num_pipeline_stages], warp_group_thread_idx: Int)

Pipeline-based consumer loop using ProducerConsumerPipeline.

This is an alternative implementation of consumer_main_loop that uses the SM100 ProducerConsumerPipeline for synchronization instead of RingBuffer.

promote_to_cuda_cores​

static def promote_to_cuda_cores(c_reg_tile: LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL], final_c_reg_tile: LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL])

Promote FP8 accumulation to higher precision using CUDA cores.

When using FP8 data types, tensor cores accumulate in limited precision. To maintain accuracy over many accumulations, we periodically add the intermediate results to a higher-precision accumulator using CUDA cores.

This technique is commonly used in production libraries like cuBLAS to achieve both high performance (from FP8 tensor cores) and good accuracy.
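
The numerical effect of periodic promotion can be reproduced on the host. The Python sketch below uses IEEE half precision (via the standard `struct` module) as a stand-in for a limited-precision accumulator and a Python float as the higher-precision accumulator. The choice of half precision, the function names, and the reset of the partial sum after each promotion are assumptions for illustration. The variables `partial` and `final_acc` only loosely correspond to c_reg_tile and final_c_reg_tile.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def accumulate_low_precision(values):
    """Keep the running sum in half precision the whole time."""
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + v)
    return acc


def accumulate_with_promotion(values, promotion_frequency):
    """Accumulate in half precision, but flush to a wide accumulator periodically."""
    final_acc = 0.0  # high-precision accumulator
    partial = 0.0    # limited-precision accumulator
    for i, v in enumerate(values, start=1):
        partial = to_fp16(partial + v)
        if i % promotion_frequency == 0:
            final_acc += partial  # the promotion step
            partial = 0.0
    return final_acc + partial


values = [0.001] * 10000  # exact sum is 10.0
print(accumulate_low_precision(values))
print(accumulate_with_promotion(values, promotion_frequency=16))
```

Without promotion the running sum stalls once a single addend falls below half the spacing between representable values. With promotion the partial sums stay small, so each addition loses far less precision.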

wgmma​

static def wgmma(wgmma_op: TensorCoreAsync[Self.accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: Int, a_tile: TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], b_tile: TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], c_reg_tile: LayoutTensor[Self.accum_type, Layout.row_major(Int((mul (block_tile_shape[Int(1)] // wgmma_shape[Int(1)]), ((block_tile_shape[Int(0)] // wgmma_shape[Int(0)]) // Int((add (num_threads // Int(128)), -1))))), (Int((mul wgmma_shape[Int(0)], wgmma_shape[Int(1)])) // Int(128))), MutAnyOrigin, address_space=AddressSpace.LOCAL])