Mojo struct

TensorCore

struct TensorCore[out_type: DType, in_type: DType, shape: Index[3], transpose_b: Bool = False]

TensorCore provides an abstraction for GPU tensor core hardware to perform optimized matrix operations.

This struct encapsulates the functionality required to efficiently map matrix operations to Tensor Cores on NVIDIA and AMD GPUs. It handles loading matrix fragments, performing matrix multiply-accumulate operations, and storing results with hardware-specific optimizations.

Note: Supported shapes and data types depend on the GPU hardware.

For NVIDIA GPUs:

  • float32: 16×8×8 or 16×8×4
  • half-precision: 16×8×16
  • float8: 16×8×32

For AMD GPUs:

  • float32: 16×16×4
  • half-precision: 16×16×16
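
The following minimal single-warp sketch shows the typical lifecycle: instantiate a TensorCore for one MMA tile, load the A, B, and C fragments, run the multiply-accumulate, and store the result. The function wrapper, tile names, and layouts are illustrative assumptions, not part of this API:

    from layout import Layout, LayoutTensor
    from layout.tensor_core import TensorCore
    from utils.index import Index

    alias M = 16
    alias N = 8
    alias K = 8

    fn single_warp_matmul(
        a: LayoutTensor[DType.float32, Layout.row_major(M, K), MutableAnyOrigin],
        b: LayoutTensor[DType.float32, Layout.row_major(K, N), MutableAnyOrigin],
        c: LayoutTensor[DType.float32, Layout.row_major(M, N), MutableAnyOrigin],
    ):
        # One 16x8x8 float32 MMA tile (the NVIDIA fp32 shape).
        var mma = TensorCore[DType.float32, DType.float32, Index(M, N, K)]()

        # Each thread in the warp loads its share of the fragments.
        var a_reg = mma.load_a(a)
        var b_reg = mma.load_b(b)
        var c_reg = mma.load_c(c)

        # d = a * b + c, executed on the tensor core.
        var d_reg = mma.mma_op(a_reg, b_reg, c_reg)

        # Write each thread's result fragments back to memory.
        mma.store_d(c, d_reg)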

Parameters

  • out_type (DType): The data type for output/accumulation operations.
  • in_type (DType): The data type for input matrix elements.
  • shape (Index[3]): The shape parameters for the matrix operation in the form [M, N, K] where M×N is the output shape and K is the inner dimension.
  • transpose_b (Bool): Whether to transpose the B matrix before multiplication. Defaults to False.

Aliases

  • supported_fp32: True if in_type is float32 and shape is IndexList(16, 8, 8) on NVIDIA GPUs or IndexList(16, 16, 4) on AMD GPUs.
  • supported_half: True if in_type is a half-precision float type and shape is IndexList(16, 8, 16) on NVIDIA GPUs or IndexList(16, 16, 16) on AMD GPUs.
  • supported_fp8: True if in_type is float8_e4m3fn or float8_e5m2 and shape is IndexList(16, 8, 32).
  • a_reg_type = SIMD[in_type, num_matrix_reg[...]()]: The per-thread register vector type for A-matrix fragments.
  • b_reg_type = SIMD[in_type, num_matrix_reg[...]()]: The per-thread register vector type for B-matrix fragments.
  • c_reg_type = SIMD[out_type, num_matrix_reg[...]()]: The per-thread register vector type for C-matrix fragments.
  • c_reg_tile_type = LayoutTensor[out_type, col_major(1, num_matrix_reg[...]()), MutableAnyOrigin, address_space=AddressSpace(5)]: The register-resident tile type used for C/D fragments.

Implemented traits

AnyType, UnknownDestructibility

Methods

__init__

__init__(out self)

Initialize a new TensorCore instance.

get_shapes

static get_shapes[out_type: DType, in_type: DType]() -> List[Index[3]]

Get supported shapes for given data types.

Returns a list of valid shapes for the specified output and input data types.

Note: The returned shapes are hardware-dependent. Different shapes are supported for different combinations of input and output types.

Parameters:

  • out_type (DType): The output/accumulation data type.
  • in_type (DType): The input matrix data type.

Returns:

List[Index[3]]: Valid shapes for the matrix operations given the specified types.
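
As an illustrative sketch, the shapes can be queried to pick a tile size. The struct parameters used to reach the static method here are placeholders for the lookup itself:

    from layout.tensor_core import TensorCore
    from utils.index import Index

    fn print_supported_shapes():
        # Shapes valid for float32 accumulation with float32 inputs on
        # the target GPU (e.g. [16, 8, 8] on NVIDIA, [16, 16, 4] on AMD).
        alias TC = TensorCore[DType.float32, DType.float32, Index(16, 8, 8)]
        var shapes = TC.get_shapes[DType.float32, DType.float32]()
        for i in range(len(shapes)):
            print(shapes[i])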

load_a

load_a[swizzle: OptionalReg[Swizzle] = None](self, a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_a_reg_tile_layout[...](), MutableAnyOrigin, address_space=AddressSpace(5)]

Load the A matrix fragments.

Loads matrix A from memory into a LayoutTensor suitable for tensor core operations.

Parameters:

  • swizzle (OptionalReg[Swizzle]): Optional swizzle pattern for optimal memory access (AMD only).

Args:

  • a (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source matrix A data.

Returns:

The loaded matrix fragments as a LayoutTensor.
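
As a minimal sketch (the `mma` instance and the `a_tile` view are assumptions set up by the surrounding kernel):

    # Load this warp's A fragments from memory into registers.
    var a_reg = mma.load_a(a_tile)

    # On AMD GPUs, a swizzle pattern may be supplied as a parameter:
    # var a_reg = mma.load_a[swizzle=some_swizzle](a_tile)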

load_a[swizzle: OptionalReg[Swizzle] = None](self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0))

Load A matrix fragments from shared memory.

Optimized version for loading A matrix fragments from shared memory.

Parameters:

  • swizzle (OptionalReg[Swizzle]): Optional memory access pattern used to optimize memory bandwidth.

Args:

  • warp_tile (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source data in shared memory.
  • fragments (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The destination tensor for fragments.
  • mma_tile_coord_k (UInt): The K coordinate of the MMA tile. Defaults to 0.
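
A sketch of the intended pattern, with `mma`, the warp tiles, the fragment tensors, and `num_k_mmas` all assumed to be set up by the surrounding kernel:

    # Refill the per-thread register fragments from the shared-memory
    # warp tile once per K step; each refill is followed by an MMA.
    for k_mma in range(num_k_mmas):
        mma.load_a(warp_tile_a, a_frags, mma_tile_coord_k=UInt(k_mma))
        mma.load_b(warp_tile_b, b_frags, mma_tile_coord_k=UInt(k_mma))
        # ... issue the multiply-accumulate with these fragments (see mma).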

load_b

load_b[swizzle: OptionalReg[Swizzle] = None](self, b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_b_reg_tile_layout[...](), MutableAnyOrigin, address_space=AddressSpace(5)]

Load the B matrix fragments.

Loads matrix B from memory into a LayoutTensor suitable for tensor core operations. The function handles different hardware architectures and memory access patterns.

Note: If transpose_b is True, the B matrix will be transposed during loading. This is more efficient than transposing the matrix in memory.

Parameters:

  • swizzle (OptionalReg[Swizzle]): Optional swizzle pattern for optimal memory access (AMD only). Will cause an error if used with NVIDIA GPUs.

Args:

  • b (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source matrix B data.

Returns:

The loaded matrix fragments as a LayoutTensor.
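
For instance, with transpose_b=True the same call consumes B in transposed form, which is typically cheaper than transposing in memory. The `b_tile` name is an assumption:

    # B is stored as an N x K tile; transpose_b=True makes the tensor
    # core consume it as K x N without an explicit transpose pass.
    var mma_t = TensorCore[
        DType.float32, DType.float32, Index(16, 8, 8), transpose_b=True
    ]()
    var b_reg = mma_t.load_b(b_tile)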

load_b[swizzle: OptionalReg[Swizzle] = None](self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0), warp_tile_coord_n: UInt = UInt(0))

Load B matrix fragments from shared memory into registers for tensor core operations.

This function loads matrix B fragments from a warp tile in shared memory into register fragments for use in tensor core matrix multiply operations. It handles hardware-specific optimizations for both NVIDIA and AMD GPUs.

Note: The warp_tile must be in shared memory. For NVIDIA GPUs, swizzle must be None. For AMD GPUs, providing an appropriate swizzle pattern can improve performance.

Parameters:

  • swizzle (OptionalReg[Swizzle]): Optional memory access pattern for AMD GPUs to optimize memory bandwidth. Must be None when running on NVIDIA GPUs, which always apply their own fixed swizzle pattern.

Args:

  • warp_tile (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Source LayoutTensor in shared memory containing the B matrix data.
  • fragments (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Destination LayoutTensor to store the loaded matrix fragments.
  • mma_tile_coord_k (UInt): K-dimension coordinate within the warp tile. Defaults to 0.
  • warp_tile_coord_n (UInt): N-dimension coordinate within the warp tile. Defaults to 0.

load_b(self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], scales: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0))

Load quantized B matrix fragments from shared memory with dequantization.

This function loads int4 quantized matrix B fragments from shared memory, dequantizes them using the provided scales, and stores the result in register fragments for tensor core operations.

Notes:

- The `warp_tile` must be in shared memory.
- The `fragments` and `scales` must be in local memory.
- This function only supports half-precision data types (bfloat16, float16).
- The quantized data is stored as int4 values packed into int32 elements.
- Each thread processes multiple fragments by unpacking and dequantizing the int4 values.

Args:

  • warp_tile (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Source LayoutTensor in shared memory containing the quantized B matrix data.
  • fragments (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Destination LayoutTensor to store the dequantized matrix fragments.
  • scales (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): LayoutTensor containing the scaling factors for dequantization.
  • mma_tile_coord_k (UInt): K-dimension coordinate within the warp tile. Defaults to 0.
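
A call-shape sketch, with all tensor names assumed: `warp_tile` holds the packed int4 weights in shared memory, `scales` the dequantization factors, and `b_frags` receives half-precision fragments in registers:

    # Dequantize-on-load: unpack int4 values from their int32 containers,
    # scale them, and store bfloat16/float16 fragments for the tensor core.
    mma.load_b(warp_tile, b_frags, scales, mma_tile_coord_k=UInt(0))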

load_c

load_c(self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[...]()), MutableAnyOrigin, address_space=AddressSpace(5)]

Load the C matrix fragments.

Loads matrix C from memory into a LayoutTensor suitable for tensor core operations. The function handles different hardware architectures and memory access patterns.

Args:

  • c (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source matrix C data.

Returns:

The loaded matrix fragments as a LayoutTensor.

store_d

store_d(self, d_dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], d_src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])

Store matrix D to destination memory.

Stores the result matrix D from tensor core computation to the destination memory.

Args:

  • d_dst (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The destination tensor to store the result.
  • d_src (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source tensor containing the computed result.

mma_op

mma_op(self, a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[...]()), MutableAnyOrigin, address_space=AddressSpace(5)]

Perform matrix multiply-accumulate operation (MMA).

Executes D = A * B + C using tensor cores.

Args:

  • a (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The A matrix input.
  • b (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The B matrix input.
  • c (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The C matrix input for accumulation.

Returns:

Self.c_reg_tile_type: The result of the MMA operation.
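
Note the contrast with mma below: mma_op returns the result as a new register tile rather than accumulating in place. A fragment sketch with assumed names:

    # d_reg is a fresh c_reg_tile_type holding a * b + c;
    # the input c_reg is left unmodified.
    var d_reg = mma.mma_op(a_reg, b_reg, c_reg)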

mma

mma(self, a_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])

Perform matrix multiply-accumulate operation using tensor cores.

Executes C = A * B + C using tensor cores, where A, B, and C are matrix fragments stored in register memory. This function handles the mapping of fragments to hardware tensor core operations.

Notes:

- All fragments must be properly loaded using the corresponding load functions.
- The function assumes fragments are vectorized layout tensors with dimensions num_vectors x 1.
- The c_frag shape[0] must equal num_m_mmas * num_n_mmas.
- The result is accumulated in-place in c_frag.

Args:

  • a_frag (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix A fragments as a LayoutTensor.
  • b_frag (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix B fragments as a LayoutTensor.
  • c_frag (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix C fragments as a LayoutTensor for both input and output.
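
A sketch of the inner loop of a tiled kernel, with all fragment names assumed to be loaded as described above:

    # One K step of the accumulation: c_frag += a_frags * b_frags.
    # Unlike mma_op, mma() writes the result back into c_frag.
    mma.mma(a_frags, b_frags, c_frag)

    # After the K loop, the accumulated tile can be written out:
    # mma.store_d(d_tile, c_frag)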