Mojo struct

TensorCoreAsync

struct TensorCoreAsync[c_type: DType, a_type: DType, b_type: DType, mma_shape: Index[3], /, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, transpose_b: Bool = False]

High-performance asynchronous tensor core operations for matrix multiplication.

This struct provides methods for using NVIDIA's Tensor Cores to perform asynchronous warp-group matrix multiply-accumulate (WGMMA) operations, with support for various data types and swizzling configurations.

Parameters

  • c_type (DType): Data type of the output matrix C.
  • a_type (DType): Data type of the input matrix A.
  • b_type (DType): Data type of the input matrix B.
  • mma_shape (Index[3]): Dimensions for the matrix multiply-accumulate (MMA) operation as [M, N, K].
  • a_swizzle (TensorMapSwizzle): Swizzling mode for matrix A (default: SWIZZLE_NONE).
  • b_swizzle (TensorMapSwizzle): Swizzling mode for matrix B (default: SWIZZLE_NONE).
  • transpose_b (Bool): Whether to transpose matrix B (default: False).
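
As a sketch of how these parameters fit together, the following instantiates the struct for a bf16 matmul accumulated in f32. The import paths, the `SWIZZLE_128B` enum value, and the `Index(...)` constructor are assumptions for illustration, not taken from this page:

```mojo
from layout.tensor_core_async import TensorCoreAsync  # assumed import path
from layout.tma_async import TensorMapSwizzle         # assumed import path
from utils.index import Index

# 64x256x16 WGMMA tile shape, bf16 operands accumulated in f32,
# 128-byte swizzling on both operands, B stored transposed.
alias MmaOp = TensorCoreAsync[
    DType.float32,       # c_type
    DType.bfloat16,      # a_type
    DType.bfloat16,      # b_type
    Index(64, 256, 16),  # mma_shape as [M, N, K]
    a_swizzle = TensorMapSwizzle.SWIZZLE_128B,
    b_swizzle = TensorMapSwizzle.SWIZZLE_128B,
    transpose_b = True,
]
```

Because `mma_shape` is validated at initialization, an unsupported [M, N, K] combination is rejected at compile time rather than at runtime.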

Implemented traits

AnyType, UnknownDestructibility

Methods

__init__

__init__(out self)

Initialize the TensorCoreAsync instance.

Ensures that the provided MMA shape is supported.

Note: Fails to compile if mma_shape is not supported.

wgmma

static wgmma[num_warp_groups: Int = 1, scale_c: Int = 1, scale_a: Int = 1, scale_b: Int = 1](a_smem_tile: LayoutTensor[a_type, layout, origin, address_space=AddressSpace(3), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b_smem_tile: LayoutTensor[b_type, layout, origin, address_space=AddressSpace(3), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[c_type, layout, origin, address_space=AddressSpace(5), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], wg_idx: Int = 0)

Perform asynchronous matrix multiplication using warp group matrix multiply-accumulate (WGMMA).

This method handles the case where both A and B matrices are in shared memory.

Parameters:

  • num_warp_groups (Int): Number of warp groups to distribute work across (default: 1).
  • scale_c (Int): Scale factor for matrix C. Valid values are 1 or 0 (default: 1).
  • scale_a (Int): Scale factor for matrix A. Valid values are 1 or -1 (default: 1).
  • scale_b (Int): Scale factor for matrix B. Valid values are 1 or -1 (default: 1).

Args:

  • a_smem_tile (LayoutTensor[a_type, layout, origin, address_space=AddressSpace(3), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix A in shared memory.
  • b_smem_tile (LayoutTensor[b_type, layout, origin, address_space=AddressSpace(3), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix B in shared memory.
  • c_reg_tile (LayoutTensor[c_type, layout, origin, address_space=AddressSpace(5), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Output matrix C in register memory.
  • wg_idx (Int): Warp group index for multi-warp group scenarios (default: 0).
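
Inside a kernel, a call to this overload might look like the following sketch. The `MmaOp` alias and the tile variable names are illustrative assumptions; the tiles are `LayoutTensor` values whose address spaces must match the signature above:

```mojo
# Both operand tiles live in shared memory (AddressSpace(3));
# the accumulator lives in registers (AddressSpace(5)). wg_idx
# selects which slice of the work this warp group issues when
# num_warp_groups > 1.
MmaOp.wgmma(
    a_smem_tile,  # A tile in shared memory
    b_smem_tile,  # B tile in shared memory
    c_reg_tile,   # C accumulator in registers
    wg_idx = 0,
)
```

Setting `scale_c = 0` overwrites the accumulator instead of accumulating into it, which is how the first iteration of a K-loop is typically handled.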

static wgmma(a_frag: LayoutTensor[a_type, layout, origin, address_space=AddressSpace(5), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b_smem_tile: LayoutTensor[b_type, layout, origin, address_space=AddressSpace(3), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[c_type, layout, origin, address_space=AddressSpace(5), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])

Perform asynchronous matrix multiplication using warp group matrix multiply-accumulate (WGMMA).

This overloaded method handles the case where matrix A is in register memory and matrix B is in shared memory.

Args:

  • a_frag (LayoutTensor[a_type, layout, origin, address_space=AddressSpace(5), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix A in register memory.
  • b_smem_tile (LayoutTensor[b_type, layout, origin, address_space=AddressSpace(3), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix B in shared memory.
  • c_reg_tile (LayoutTensor[c_type, layout, origin, address_space=AddressSpace(5), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Output matrix C in register memory.
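
A hedged sketch of this overload, with the same illustrative `MmaOp` alias and assumed variable names as above:

```mojo
# Overload for when the A operand has already been loaded into
# registers (e.g. by an earlier pipeline stage), while B remains
# in shared memory.
MmaOp.wgmma(
    a_frag,       # A fragment in registers (AddressSpace(5))
    b_smem_tile,  # B tile in shared memory (AddressSpace(3))
    c_reg_tile,   # C accumulator in registers (AddressSpace(5))
)
```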

arrive

static arrive()

Ensures memory consistency by creating a fence for WGMMA operations.

This method should be called after writing operand tiles to shared memory and before issuing WGMMA operations, to ensure those writes are complete and visible to the asynchronous Tensor Core instructions.

commit_group

static commit_group()

Commits all pending WGMMA operations into a group.

This synchronizes the warp group and batches the WGMMA operations issued since the last commit, so they can later be awaited as a unit with wait_group().

wait_group

static wait_group[group: Int = 0]()

Waits until the number of pending WGMMA groups drops to the given threshold.

This method corresponds to the PTX `wgmma.wait_group` instruction: it blocks until at most `group` committed groups remain outstanding, so the default of 0 waits for all previously committed WGMMA operations to complete.

Parameters:

  • group (Int): The maximum number of committed groups that may remain pending (default: 0, i.e. wait for all).
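
Taken together, the synchronization methods follow a fence → issue → commit → wait pattern. A sketch of the typical sequence inside a kernel's main loop, using the same assumed `MmaOp` alias and tile names as the earlier examples:

```mojo
# 1. Make prior shared-memory writes visible to the async proxy.
MmaOp.arrive()

# 2. Issue the asynchronous matrix multiply on the current tiles.
MmaOp.wgmma(a_smem_tile, b_smem_tile, c_reg_tile)

# 3. Commit everything issued so far as one group.
MmaOp.commit_group()

# 4. Block until no committed groups remain pending; afterwards
#    c_reg_tile holds the accumulated results.
MmaOp.wait_group[0]()
```

Passing a nonzero `group` to `wait_group` allows a software pipeline to keep a bounded number of MMA groups in flight while overlapping them with data movement.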