Mojo function
wgmma_async
wgmma_async[m: Int, n: Int, k: Int, c_dtype: DType, width: Int, /, *, a_type: DType, b_type: DType, accum_type: DType = c_dtype, layout_a: StringLiteral = "row", layout_b: StringLiteral = "col"](mat_a_desc: WGMMADescriptor[dtype], mat_b_desc: WGMMADescriptor[dtype], c_reg: SIMD[c_dtype, width]) -> SIMD[c_dtype, width]
Performs a warp group asynchronous matrix multiply-accumulate (WGMMA) operation.
This overload executes an asynchronous matrix multiplication using warp group MMA instructions, with both matrix A and matrix B supplied through WGMMA descriptors. It supports input data types including TensorFloat-32 (tf32), bfloat16, float16, float8, int8, and uint8.
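Constraints:
- The number of output registers must match the instruction shape:
  (m * n // 128) * sizeof(accum_type) == width * sizeof(c_dtype)
  Here 128 is the number of threads in a warp group (four warps), so m * n // 128 is the number of output elements each thread holds.
- Data type combinations must be compatible with hardware WGMMA instructions.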
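As a worked check of the register-count constraint, the sketch below solves it for width in plain Python (not Mojo). Byte sizes are passed in by hand, and the helper name c_reg_width is hypothetical; it assumes the 128 in the formula is the warp group thread count.

```python
# Hypothetical plain-Python helper (not part of the Mojo API) that solves
#   (m * n // 128) * sizeof(accum_type) == width * sizeof(c_dtype)
# for width, given element sizes in bytes.

WARPGROUP_THREADS = 128  # threads in one warp group (four warps)


def c_reg_width(m: int, n: int, accum_bytes: int, c_bytes: int) -> int:
    """Return the SIMD width of c_reg that satisfies the constraint."""
    # Bytes of accumulator data each thread owns for an m x n output tile.
    per_thread_bytes = (m * n // WARPGROUP_THREADS) * accum_bytes
    if per_thread_bytes % c_bytes != 0:
        raise ValueError("no integer width satisfies the constraint")
    return per_thread_bytes // c_bytes


# 64 x 128 tile, float32 accumulation held in a float32 SIMD register.
print(c_reg_width(64, 128, accum_bytes=4, c_bytes=4))  # 64
# 64 x 256 tile, float32 accumulation held in a float32 SIMD register.
print(c_reg_width(64, 256, accum_bytes=4, c_bytes=4))  # 128
```

For the common float32 case, width is therefore simply m * n / 128.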
Parameters:
- m (Int): Number of rows in matrix A and in the output matrix.
- n (Int): Number of columns in matrix B and in the output matrix.
- k (Int): Inner dimension: the number of columns in matrix A and rows in matrix B.
- c_dtype (DType): Data type of the output matrix C.
- width (Int): Width of the SIMD register holding matrix C.
- a_type (DType): Data type of matrix A.
- b_type (DType): Data type of matrix B.
- accum_type (DType): Accumulation data type (defaults to c_dtype).
- layout_a (StringLiteral): Memory layout of matrix A ("row" or "col", defaults to "row").
- layout_b (StringLiteral): Memory layout of matrix B ("row" or "col", defaults to "col").
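The valid k depends on the input type. As a sketch based on the PTX ISA rather than on this Mojo API: every dense wgmma shape has m = 64, and k always covers 32 bytes of input, so k = 32 / sizeof(input element). The plain-Python helper wgmma_k and its string dtype names are hypothetical.

```python
# Hypothetical plain-Python helper (not part of the Mojo API).
# Assumption from the PTX ISA: in dense WGMMA the K dimension spans
# 32 bytes of input, so k = 32 // element size in bytes.

K_BYTES = 32

# Element sizes in bytes; the names are plain strings, not Mojo DType values.
INPUT_BYTES = {
    "tf32": 4,
    "bfloat16": 2,
    "float16": 2,
    "float8_e4m3": 1,
    "float8_e5m2": 1,
    "int8": 1,
    "uint8": 1,
}


def wgmma_k(input_dtype: str) -> int:
    """Return the k of a dense WGMMA instruction for the given input type."""
    return K_BYTES // INPUT_BYTES[input_dtype]


for name in ("tf32", "bfloat16", "float8_e4m3", "int8"):
    print(name, wgmma_k(name))
# tf32 8
# bfloat16 16
# float8_e4m3 32
# int8 32
```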
Args:
- mat_a_desc (WGMMADescriptor[dtype]): WGMMA descriptor for matrix A.
- mat_b_desc (WGMMADescriptor[dtype]): WGMMA descriptor for matrix B.
- c_reg (SIMD[c_dtype, width]): SIMD register containing the matrix C (accumulator) values.
Returns:
SIMD register containing the result of the matrix multiplication.
wgmma_async[m: Int, n: Int, k: Int, a_dtype: DType, c_dtype: DType, frag_a_width: Int, frag_c_width: Int, /, *, a_type: DType, b_type: DType, accum_type: DType = c_dtype, layout_a: StringLiteral = "row", layout_b: StringLiteral = "col"](mat_a_frag: SIMD[a_dtype, frag_a_width], mat_b_desc: WGMMADescriptor[dtype], c: SIMD[c_dtype, frag_c_width]) -> SIMD[c_dtype, frag_c_width]
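Performs a warp group asynchronous matrix multiply-accumulate (WGMMA) operation, with matrix A supplied as a register fragment and matrix B supplied through a WGMMA descriptor.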
Currently only supports:
- m=64, k=16.
- BF16 input types.
- FP32 accumulation.
- Row major matrix A.
- Column major matrix B (or row major for BF16).
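The restrictions above can be encoded as a plain-Python check (not Mojo). The helper frag_overload_sizes and its string type names are hypothetical. The returned counts assume the 64 x 16 A tile and the 64 x n C tile are split evenly over the 128 threads of a warp group; they are element counts, so the frag_a_width the Mojo API expects may differ if elements are packed into wider registers.

```python
# Hypothetical plain-Python helper (not part of the Mojo API) mirroring the
# documented restrictions of the register-fragment overload.

WARPGROUP_THREADS = 128  # threads in one warp group


def frag_overload_sizes(m, n, k, a_type, b_type, accum_type,
                        layout_a="row", layout_b="col"):
    """Check the documented restrictions, then return
    (A elements per thread, C elements per thread)."""
    if (m, k) != (64, 16):
        raise ValueError("only m=64, k=16 is supported")
    if a_type != "bfloat16" or b_type != "bfloat16":
        raise ValueError("only BF16 input types are supported")
    if accum_type != "float32":
        raise ValueError("only FP32 accumulation is supported")
    if layout_a != "row":
        raise ValueError("matrix A must be row major")
    # Column major B is the default; row major is also allowed for BF16.
    if layout_b not in ("col", "row"):
        raise ValueError("layout_b must be 'col' or 'row'")
    a_elems = m * k // WARPGROUP_THREADS  # 64 * 16 / 128 = 8
    c_elems = m * n // WARPGROUP_THREADS  # e.g. 64 * 128 / 128 = 64
    return a_elems, c_elems


print(frag_overload_sizes(64, 128, 16, "bfloat16", "bfloat16", "float32"))  # (8, 64)
```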
Parameters:
- m (Int): Number of rows in the output matrix.
- n (Int): Number of columns in the output matrix.
- k (Int): Inner dimension of the matrix multiplication.
- a_dtype (DType): Data type of the matrix A fragment.
- c_dtype (DType): Data type of the output matrix C.
- frag_a_width (Int): Width of the matrix A fragment.
- frag_c_width (Int): Width of the output matrix C fragment.
- a_type (DType): Data type of matrix A.
- b_type (DType): Data type of matrix B.
- accum_type (DType): Data type used for accumulation (defaults to c_dtype).
- layout_a (StringLiteral): Layout of matrix A ("row" or "col", defaults to "row").
- layout_b (StringLiteral): Layout of matrix B ("row" or "col", defaults to "col").
Args:
- mat_a_frag (SIMD[a_dtype, frag_a_width]): Fragment containing matrix A data.
- mat_b_desc (WGMMADescriptor[dtype]): Descriptor for matrix B data.
- c (SIMD[c_dtype, frag_c_width]): Fragment containing matrix C data.
Returns:
Updated matrix C fragment after WGMMA operation.