Mojo struct
TensorCore
struct TensorCore[out_type: DType, in_type: DType, shape: Index[3], transpose_b: Bool = False]
TensorCore provides an abstraction for GPU tensor core hardware to perform optimized matrix operations.
This struct encapsulates the functionality required to efficiently map matrix operations to Tensor Cores on NVIDIA and AMD GPUs. It handles loading matrix fragments, performing matrix multiply-accumulate operations, and storing results with hardware-specific optimizations.
Note: Supported shapes and data types depend on the GPU hardware.

For NVIDIA GPUs:

- float32: 16×8×8 or 16×8×4
- half-precision: 16×8×16
- float8: 16×8×32

For AMD GPUs:

- float32: 16×16×4
- half-precision: 16×16×16
Parameters
- out_type (DType): The data type for output/accumulation operations.
- in_type (DType): The data type for input matrix elements.
- shape (Index[3]): The shape parameters for the matrix operation in the form [M, N, K], where M×N is the output shape and K is the inner dimension.
- transpose_b (Bool): Whether to transpose the B matrix before multiplication. Defaults to False.
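For example, a minimal instantiation for float32 inputs and outputs with the NVIDIA 16×8×8 shape might look like the following sketch (import paths assumed from the standard layout package):

```mojo
from layout.tensor_core import TensorCore
from utils.index import Index

fn make_tensor_cores():
    # float32 in, float32 out, 16x8x8 MMA shape (NVIDIA; AMD uses 16x16x4).
    var mma = TensorCore[DType.float32, DType.float32, Index(16, 8, 8)]()
    # With transpose_b=True, B is transposed as it is loaded.
    var mma_t = TensorCore[
        DType.float32, DType.float32, Index(16, 8, 8), transpose_b=True
    ]()
```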
Aliases
- supported_fp32: True when in_type is float32 and shape is IndexList(16, 8, 8) on NVIDIA GPUs or IndexList(16, 16, 4) on AMD GPUs.
- supported_half: True when in_type is a half-precision float and shape is IndexList(16, 8, 16) on NVIDIA GPUs or IndexList(16, 16, 16) on AMD GPUs.
- supported_fp8: True when in_type is float8_e4m3fn or float8_e5m2 and shape is IndexList(16, 8, 32).
- a_reg_type = SIMD[in_type, num_matrix_reg[::Int,::Int]()]: Register vector type for A matrix fragments.
- b_reg_type = SIMD[in_type, num_matrix_reg[::Int,::Int]()]: Register vector type for B matrix fragments.
- c_reg_type = SIMD[out_type, num_matrix_reg[::Int,::Int]()]: Register vector type for C matrix fragments.
- c_reg_tile_type = LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), MutableAnyOrigin, address_space=AddressSpace(5)]: Register tile type for the accumulator.
Implemented traits
AnyType, UnknownDestructibility
Methods
__init__
__init__(out self)
Initialize a new TensorCore instance.
get_shapes
static get_shapes[out_type: DType, in_type: DType]() -> List[Index[3]]
Get supported shapes for given data types.
Returns a list of valid shapes for the specified output and input data types.
Note: The returned shapes are hardware-dependent. Different shapes are supported for different combinations of input and output types.
Parameters:
- out_type (DType): The output/accumulation data type.
- in_type (DType): The input matrix data type.
Returns:
List[IndexList[3]]: Valid shapes for the matrix operations given the specified types.
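As a hedged sketch, the shapes can be queried through a bound instantiation of the struct (assumes IndexList values are printable):

```mojo
fn query_shapes():
    alias TC = TensorCore[DType.float32, DType.float32, Index(16, 8, 8)]
    # Valid [M, N, K] shapes for float32 inputs accumulating to float32.
    var shapes = TC.get_shapes[DType.float32, DType.float32]()
    for i in range(len(shapes)):
        print(shapes[i])
```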
load_a
load_a[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]({:i1 0, 1})](self, a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_a_reg_tile_layout[layout::layout::Layout,stdlib::utils::index::IndexList[::Int(), MutableAnyOrigin, address_space=AddressSpace(5)]
Load the A matrix fragments.
Loads matrix A from memory into a LayoutTensor suitable for tensor core operations.
Parameters:
- swizzle (OptionalReg[Swizzle]): Optional swizzle pattern for optimal memory access (AMD only).
Args:
- a (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source matrix A data.

Returns:

The loaded matrix fragments as a LayoutTensor.
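In kernel code, the direct form is a single call per warp tile; a sketch, assuming `mma` is a TensorCore instance and `a_tile` is this warp's tile of A from the surrounding kernel:

```mojo
# Inside a GPU kernel: load A fragments into registers.
var a_reg = mma.load_a(a_tile)
```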
load_a[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]({:i1 0, 1})](self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0))
Load A matrix fragments from shared memory.
Optimized version for loading A matrix fragments from shared memory.
Parameters:
- swizzle (OptionalReg[Swizzle]): Optional memory access pattern to optimize memory bandwidth.
Args:
- warp_tile (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source data in shared memory.
- fragments (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The destination tensor for fragments.
- mma_tile_coord_k (UInt): The K coordinate of the MMA tile. Defaults to 0.
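A sketch of the shared-memory form, with `warp_tile` in shared memory and `a_frags` a pre-allocated register tile (names illustrative):

```mojo
# Load the k-th MMA tile of A from shared memory into fragments.
mma.load_a(warp_tile, a_frags, mma_tile_coord_k=UInt(0))
```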
load_b
load_b[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]({:i1 0, 1})](self, b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_b_reg_tile_layout[layout::layout::Layout,stdlib::utils::index::IndexList[::Int(), MutableAnyOrigin, address_space=AddressSpace(5)]
Load the B matrix fragments.
Loads matrix B from memory into a LayoutTensor suitable for tensor core operations. The function handles different hardware architectures and memory access patterns.
Note: If transpose_b is True, the B matrix will be transposed during loading, which is more efficient than transposing the matrix in memory.
Parameters:
- swizzle (OptionalReg[Swizzle]): Optional swizzle pattern for optimal memory access (AMD only). Using it on NVIDIA GPUs causes an error.
Args:
- b (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source matrix B data.

Returns:

The loaded matrix fragments as a LayoutTensor.
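A sketch mirroring load_a, assuming `b_tile` is this warp's tile of B:

```mojo
# Inside a GPU kernel: load B fragments
# (B is transposed during the load if transpose_b is True).
var b_reg = mma.load_b(b_tile)
```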
load_b[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]({:i1 0, 1})](self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0), warp_tile_coord_n: UInt = UInt(0))
Load B matrix fragments from shared memory into registers for tensor core operations.
This function loads matrix B fragments from a warp tile in shared memory into register fragments for use in tensor core matrix multiply operations. It handles hardware-specific optimizations for both NVIDIA and AMD GPUs.
Note: The warp_tile must be in shared memory. For NVIDIA GPUs, swizzle must be None. For AMD GPUs, providing an appropriate swizzle pattern can improve performance.
Parameters:
- swizzle (OptionalReg[Swizzle]): Optional memory access pattern for AMD GPUs to optimize memory bandwidth. Must be None on NVIDIA GPUs, which always apply swizzling internally.
Args:
- warp_tile (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Source LayoutTensor in shared memory containing the B matrix data.
- fragments (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Destination LayoutTensor to store the loaded matrix fragments.
- mma_tile_coord_k (UInt): K-dimension coordinate within the warp tile. Defaults to 0.
- warp_tile_coord_n (UInt): N-dimension coordinate within the warp tile. Defaults to 0.
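A sketch of the shared-memory form (names illustrative; on NVIDIA GPUs leave swizzle at its default):

```mojo
# Load the k-th MMA tile of B from shared memory into fragments.
mma.load_b(warp_tile, b_frags, mma_tile_coord_k=UInt(0), warp_tile_coord_n=UInt(0))
```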
load_b(self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], scales: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0))
Load quantized B matrix fragments from shared memory with dequantization.
This function loads int4 quantized matrix B fragments from shared memory, dequantizes them using the provided scales, and stores the result in register fragments for tensor core operations.
Notes:
- The `warp_tile` must be in shared memory.
- The `fragments` and `scales` must be in local memory.
- This function only supports half-precision data types (bfloat16, float16).
- The quantized data is stored as int4 values packed into int32 elements.
- Each thread processes multiple fragments by unpacking and dequantizing the int4 values.
Args:
- warp_tile (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Source LayoutTensor in shared memory containing the quantized B matrix data.
- fragments (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Destination LayoutTensor to store the dequantized matrix fragments.
- scales (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): LayoutTensor containing the scaling factors for dequantization.
- mma_tile_coord_k (UInt): K-dimension coordinate within the warp tile. Defaults to 0.
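A sketch of the dequantizing load, assuming `warp_tile` holds int4 values packed into int32s in shared memory and `scales` holds the dequantization factors in local memory (names illustrative):

```mojo
# Dequantize int4-packed B data into half-precision fragments.
mma.load_b(warp_tile, b_frags, scales, mma_tile_coord_k=UInt(0))
```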
load_c
load_c(self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), MutableAnyOrigin, address_space=AddressSpace(5)]
Load the C matrix fragments.
Loads matrix C from memory into a LayoutTensor suitable for tensor core operations. The function handles different hardware architectures and memory access patterns.
Args:
- c (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source matrix C data.

Returns:

The loaded matrix fragments as a LayoutTensor.
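A sketch, assuming `c_tile` is this warp's tile of the accumulator C:

```mojo
# Inside a GPU kernel: load C fragments for accumulation.
var c_reg = mma.load_c(c_tile)
```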
store_d
store_d(self, d_dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], d_src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])
Store matrix D to destination memory.
Stores the result matrix D from tensor core computation to the destination memory.
Args:
- d_dst (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The destination tensor to store the result.
- d_src (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The source tensor containing the computed result.
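A sketch, assuming `d_reg` holds the register tile returned by mma_op (below) and `c_tile` is the destination view:

```mojo
# Write the computed register tile back to memory.
mma.store_d(c_tile, d_reg)
```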
mma_op
mma_op(self, a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), MutableAnyOrigin, address_space=AddressSpace(5)]
Perform matrix multiply-accumulate operation (MMA).
Executes D = A * B + C using tensor cores.
Args:
- a (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The A matrix input.
- b (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The B matrix input.
- c (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): The C matrix input for accumulation.
Returns:
Self.c_reg_tile_type: The result of the MMA operation.
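Taken together, a single-tile D = A * B + C on one warp might look like this hedged sketch (the tile variables and their layouts are assumptions from the surrounding kernel, not part of this API):

```mojo
# Inside a GPU kernel: one warp computes one MMA tile.
var mma = TensorCore[DType.float32, DType.float32, Index(16, 8, 8)]()
var a_reg = mma.load_a(a_tile)   # A fragments -> registers
var b_reg = mma.load_b(b_tile)   # B fragments -> registers
var c_reg = mma.load_c(c_tile)   # accumulator fragments -> registers
var d_reg = mma.mma_op(a_reg, b_reg, c_reg)  # D = A * B + C
mma.store_d(c_tile, d_reg)       # write the result back
```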
mma
mma(self, a_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])
Perform matrix multiply-accumulate operation using tensor cores.
Executes C = A * B + C using tensor cores, where A, B, and C are matrix fragments stored in register memory. This function handles the mapping of fragments to hardware tensor core operations.
Notes:
- All fragments must be properly loaded using the corresponding load functions.
- The function assumes fragments are vectorized layout tensors with dimensions num_vectors x 1.
- The c_frag shape[0] must equal num_m_mmas * num_n_mmas.
- The result is accumulated in-place in c_frag.
Args:
- a_frag (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix A fragments as a LayoutTensor.
- b_frag (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix B fragments as a LayoutTensor.
- c_frag (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]): Matrix C fragments as a LayoutTensor, used for both input and output.
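For K-loop accumulation, the in-place form avoids reallocating the accumulator each step; a sketch, assuming `a_frag`, `b_frag`, and `c_frag` were allocated to match the fragment layouts and `num_k_mmas` counts the K tiles (illustrative names):

```mojo
# c_frag accumulates in place across the K dimension.
for k in range(num_k_mmas):
    mma.load_a(a_warp_tile, a_frag, mma_tile_coord_k=UInt(k))
    mma.load_b(b_warp_tile, b_frag, mma_tile_coord_k=UInt(k))
    mma.mma(a_frag, b_frag, c_frag)
```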