Mojo module

tensor_core_async

Tensor Core Async Module

This module provides high-performance abstractions for utilizing NVIDIA's Tensor Cores to perform asynchronous matrix multiplication operations. It implements optimized memory layouts and access patterns for efficient tensor core computations.

Key components:

  • Layout creation functions for K-major and MN-major memory arrangements
  • Swizzling support for improved memory access patterns
  • WGMMA (Warp Group Matrix Multiply-Accumulate) descriptor generation
  • TensorCoreAsync struct with methods for asynchronous matrix multiplication

The module supports various data types, matrix dimensions, and memory configurations, enabling efficient implementation of deep learning primitives and other tensor operations that can leverage hardware acceleration.

Performance features:

  • Asynchronous execution model to overlap computation and memory access
  • Support for different swizzling modes to optimize memory bandwidth
  • Efficient register and shared memory utilization
  • Support for multi-warp group execution

This implementation is specifically optimized for NVIDIA GPUs with Tensor Core support.

comptime values

tile_layout_k_major_typed

comptime tile_layout_k_major_typed[dtype: DType, BM: Int, BK: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE] = Layout(Coord(Coord(Idx[8](), Idx[(BM // 8)]()), Coord(Idx[(swizzle_mode.bytes() // size_of[dtype]())](), Idx[(BK // (swizzle_mode.bytes() // size_of[dtype]()))]())), Coord(Coord(Idx[(swizzle_mode.bytes() // size_of[dtype]())](), Idx[(8 * (swizzle_mode.bytes() // size_of[dtype]()))]()), Coord(Idx[1](), Idx[0 if (BK == (swizzle_mode.bytes() // size_of[dtype]())) else (BM * (swizzle_mode.bytes() // size_of[dtype]()))]())))

K-major typed Layout for tensor core operations.

Shape ((CM, BM/CM), (sw_K, BK/sw_K)), stride ((sw_K, CM*sw_K), (1, BM*sw_K)), where CM = 8 and sw_K = swizzle_mode.bytes() // size_of[dtype](). When BK/sw_K == 1 (the tile spans a single swizzle atom in K), the outer K stride is 0 and the layout is compact.

Parameters

  • dtype (DType): Element data type of the tensor.
  • BM (Int): Size of the M dimension in the tile.
  • BK (Int): Size of the K dimension in the tile.
  • swizzle_mode (TensorMapSwizzle): Memory access pattern swizzling mode (default: SWIZZLE_NONE).
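The shape/stride arithmetic above can be sketched in plain Python (a hypothetical helper, not part of this module; `swizzle_bytes` stands in for `swizzle_mode.bytes()` and `dtype_size` for `size_of[dtype]()`):

```python
def tile_layout_k_major(dtype_size: int, BM: int, BK: int, swizzle_bytes: int):
    """Python sketch of the K-major shape/stride computation above."""
    CM = 8                               # core-matrix row count
    sw_k = swizzle_bytes // dtype_size   # swizzle atom width in elements
    shape = ((CM, BM // CM), (sw_k, BK // sw_k))
    # Outer K stride is 0 when the tile spans exactly one swizzle atom in K.
    outer_k_stride = 0 if BK == sw_k else BM * sw_k
    stride = ((sw_k, CM * sw_k), (1, outer_k_stride))
    return shape, stride

# e.g. a 2-byte dtype (bf16), BM=64, BK=64, 128-byte swizzle -> sw_k = 64
shape, stride = tile_layout_k_major(2, 64, 64, 128)
# shape == ((8, 8), (64, 1)), stride == ((64, 512), (1, 0))
```

With BK equal to the swizzle atom width, the outer K mode has extent 1 and stride 0, matching the "compact" case in the formula.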

tile_layout_mn_major_typed

comptime tile_layout_mn_major_typed[dtype: DType, mn_dim: Int, k_dim: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE] = Layout(Coord(Coord(Idx[8](), Idx[(k_dim // 8)]()), Coord(Idx[(swizzle_mode.bytes() // size_of[dtype]())](), Idx[(mn_dim // (swizzle_mode.bytes() // size_of[dtype]()))]())), Coord(Coord(Idx[(swizzle_mode.bytes() // size_of[dtype]())](), Idx[(8 * (swizzle_mode.bytes() // size_of[dtype]()))]()), Coord(Idx[1](), Idx[0 if (mn_dim == (swizzle_mode.bytes() // size_of[dtype]())) else (k_dim * (swizzle_mode.bytes() // size_of[dtype]()))]()))).transpose()

MN-major typed Layout for tensor core operations.

Equivalent to tile_layout_k_major_typed[dtype, k_dim, mn_dim, swizzle_mode].transpose().

Parameters

  • dtype (DType): Element data type of the tensor.
  • mn_dim (Int): Size of the MN dimension.
  • k_dim (Int): Size of the K dimension.
  • swizzle_mode (TensorMapSwizzle): Memory access pattern swizzling mode (default: SWIZZLE_NONE).
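The transpose relationship can be modeled in Python as well: build the K-major layout with the dimensions swapped, then exchange the two outermost (shape, stride) modes. This assumes transpose of a rank-2 layout swaps its two top-level modes; the helper names are hypothetical, not the module's API:

```python
def tile_layout_k_major(dtype_size: int, BM: int, BK: int, swizzle_bytes: int):
    """K-major shape/stride sketch (same arithmetic as the entry above)."""
    CM = 8
    sw_k = swizzle_bytes // dtype_size
    shape = ((CM, BM // CM), (sw_k, BK // sw_k))
    outer_k_stride = 0 if BK == sw_k else BM * sw_k
    stride = ((sw_k, CM * sw_k), (1, outer_k_stride))
    return shape, stride

def tile_layout_mn_major(dtype_size: int, mn_dim: int, k_dim: int,
                         swizzle_bytes: int):
    """MN-major = K-major with (BM=k_dim, BK=mn_dim), then transposed."""
    shape, stride = tile_layout_k_major(dtype_size, k_dim, mn_dim, swizzle_bytes)
    # Transposing a rank-2 layout swaps the two outermost modes.
    return (shape[1], shape[0]), (stride[1], stride[0])
```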

WGMMA_K_BYTES

comptime WGMMA_K_BYTES = 32

Size of WGMMA K dimension in bytes.
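Since the K extent is fixed in bytes, the number of K elements covered by one WGMMA depends on the element size. A small Python sketch (the helper name is illustrative, not part of the module):

```python
WGMMA_K_BYTES = 32  # fixed K extent of one WGMMA instruction, in bytes

def wgmma_k_elements(dtype_size_bytes: int) -> int:
    """Elements along K covered by one WGMMA for a given element size."""
    return WGMMA_K_BYTES // dtype_size_bytes

# 2-byte types (bf16/f16) -> 16 elements; 4-byte -> 8; 1-byte (fp8) -> 32
```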

Structs

  • TensorCoreAsync: High-performance asynchronous tensor core operations for matrix multiplication.
