Mojo struct
TensorCoreAsync
```mojo
struct TensorCoreAsync[
    c_type: DType,
    a_type: DType,
    b_type: DType,
    mma_shape: Index[3],
    /,
    a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE,
    b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE,
    transpose_b: Bool = False,
]
```
High-performance asynchronous tensor core operations for matrix multiplication.
This struct provides methods for utilizing NVIDIA's Tensor Cores for asynchronous matrix multiplication operations, with support for various data types and swizzling configurations.
Parameters
- `c_type` (`DType`): Data type of the output matrix C.
- `a_type` (`DType`): Data type of the input matrix A.
- `b_type` (`DType`): Data type of the input matrix B.
- `mma_shape` (`Index[3]`): Dimensions of the matrix multiply-accumulate (MMA) operation as `[M, N, K]`.
- `a_swizzle` (`TensorMapSwizzle`): Swizzling mode for matrix A (default: `SWIZZLE_NONE`).
- `b_swizzle` (`TensorMapSwizzle`): Swizzling mode for matrix B (default: `SWIZZLE_NONE`).
- `transpose_b` (`Bool`): Whether to transpose matrix B (default: `False`).
Implemented traits
`AnyType`, `UnknownDestructibility`
Methods
__init__
`__init__(out self)`

Initializes the `TensorCoreAsync` instance, ensuring that the provided MMA shape is supported.

Note: fails to compile if `mma_shape` is not supported.
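As a hedged sketch of how the struct parameters fit together (the import paths, the `Index` constructor call, and the specific MMA shape are illustrative assumptions, not taken from this page), a kernel might configure an instance via an alias:

```mojo
from layout.tensor_core_async import TensorCoreAsync
from utils.index import Index

# Illustrative configuration: bf16 inputs accumulated into f32,
# using a 64x256x16 WGMMA shape with B transposed.
alias MmaOp = TensorCoreAsync[
    DType.float32,       # c_type
    DType.bfloat16,      # a_type
    DType.bfloat16,      # b_type
    Index(64, 256, 16),  # mma_shape as [M, N, K]
    transpose_b=True,
]
```

Because the shape check happens at parameterization time, an unsupported `mma_shape` here fails to compile rather than erroring at runtime.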
wgmma
```mojo
static wgmma[
    num_warp_groups: Int = 1,
    scale_c: Int = 1,
    scale_a: Int = 1,
    scale_b: Int = 1,
](
    a_smem_tile: LayoutTensor[a_type, layout, origin, address_space=AddressSpace(3), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment],
    b_smem_tile: LayoutTensor[b_type, layout, origin, address_space=AddressSpace(3), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment],
    c_reg_tile: LayoutTensor[c_type, layout, origin, address_space=AddressSpace(5), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment],
    wg_idx: Int = 0,
)
```
Perform asynchronous matrix multiplication using warp group matrix multiply-accumulate (WGMMA).
This method handles the case where both A and B matrices are in shared memory.
Parameters:
- `num_warp_groups` (`Int`): Number of warp groups to distribute work across (default: 1).
- `scale_c` (`Int`): Scale factor for matrix C. Valid values are 1 or 0 (default: 1).
- `scale_a` (`Int`): Scale factor for matrix A. Valid values are 1 or -1 (default: 1).
- `scale_b` (`Int`): Scale factor for matrix B. Valid values are 1 or -1 (default: 1).
Args:
- `a_smem_tile` (`LayoutTensor`): Matrix A in shared memory.
- `b_smem_tile` (`LayoutTensor`): Matrix B in shared memory.
- `c_reg_tile` (`LayoutTensor`): Output matrix C in register memory.
- `wg_idx` (`Int`): Warp group index for multi-warp-group scenarios (default: 0).
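For orientation, a single issue of this overload inside a kernel might look like the following sketch. Here `MmaOp` stands for a parameterized `TensorCoreAsync` alias, and the tile variables are placeholders assumed to be `LayoutTensor`s in the address spaces described above:

```mojo
# a_smem_tile, b_smem_tile: tiles already staged in shared memory;
# c_reg_tile: per-thread accumulator fragment in registers.
MmaOp.wgmma(a_smem_tile, b_smem_tile, c_reg_tile)
MmaOp.commit_group()  # commit the issued WGMMA operations
MmaOp.wait_group()    # block until they complete
```

Because the multiply is asynchronous, `c_reg_tile` must not be read until the corresponding `wait_group` call returns.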
```mojo
static wgmma(
    a_frag: LayoutTensor[a_type, layout, origin, address_space=AddressSpace(5), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment],
    b_smem_tile: LayoutTensor[b_type, layout, origin, address_space=AddressSpace(3), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment],
    c_reg_tile: LayoutTensor[c_type, layout, origin, address_space=AddressSpace(5), element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment],
)
```
Perform asynchronous matrix multiplication using warp group matrix multiply-accumulate (WGMMA).
This overloaded method handles the case where matrix A is in register memory and matrix B is in shared memory.
Args:
- `a_frag` (`LayoutTensor`): Matrix A in register memory.
- `b_smem_tile` (`LayoutTensor`): Matrix B in shared memory.
- `c_reg_tile` (`LayoutTensor`): Output matrix C in register memory.
arrive
`static arrive()`
Ensures memory consistency by creating a fence for WGMMA operations.
This method should be called before committing a group to ensure all shared memory accesses are properly aligned and visible.
commit_group
`static commit_group()`
Commits the current warp group for execution.
This synchronizes the warp group and commits all pending WGMMA operations that have been previously issued.
wait_group
`static wait_group[group: Int = 0]()`
Waits for the completion of a specific warp group's operations.
This method blocks until all WGMMA operations from the specified group are complete.
Parameters:
- `group` (`Int`): The group ID to wait for (default: 0).
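Taken together, the methods above form a fence/issue/commit/wait pipeline. A hedged sketch of a main loop follows; `MmaOp` stands for a parameterized `TensorCoreAsync` alias, and `num_k_tiles` and the tile variables are placeholders, not part of this API:

```mojo
for k in range(num_k_tiles):
    # ... producer copies the k-th A and B tiles into shared memory ...
    MmaOp.arrive()  # fence: make shared-memory writes visible to WGMMA
    MmaOp.wgmma(a_smem_tile, b_smem_tile, c_reg_tile)
    MmaOp.commit_group()
MmaOp.wait_group()  # drain the committed WGMMA work before reading c_reg_tile
```

The ordering matters: `arrive` must precede the commit so shared-memory accesses are visible, and `c_reg_tile` is only safe to read after `wait_group` returns.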