Mojo function

mma

mma[kind: UMMAKind, //, cta_group: Int = 1, /, *, c_scale: SIMD[uint32, 1] = 1](a_desc: MMASmemDescriptor, b_desc: MMASmemDescriptor, c_tmem: SIMD[uint32, 1], inst_desc: UMMAInsDescriptor[kind])

Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction.

Parameters:

  • kind (UMMAKind): Data type of the matrices.
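  • cta_group (Int): Number of CTAs (cooperative thread arrays) used by the MMA. Defaults to 1.
  • c_scale (SIMD[uint32, 1]): Scale factor applied to the existing contents of the C matrix before accumulation. Must be 0 (overwrite C with the product) or 1 (add the product to C). Defaults to 1.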

Args:

  • a_desc (MMASmemDescriptor): The descriptor for the A matrix.
  • b_desc (MMASmemDescriptor): The descriptor for the B matrix.
  • c_tmem (SIMD[uint32, 1]): The address of the C matrix in tensor memory.
  • inst_desc (UMMAInsDescriptor[kind]): The descriptor for the MMA instruction.
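
The effect of c_scale can be illustrated with a small reference model. The sketch below is plain Python, not Mojo GPU code: nested lists stand in for the A and B operands and for the accumulator in tensor memory, and the hypothetical helper `mma_reference` is not part of any library. It assumes the operation computes A × B + c_scale × C and ignores data types, memory layouts, and descriptors entirely.

```python
# Reference model only: nested lists stand in for the A/B operands and the
# tensor-memory accumulator. `mma_reference` is a hypothetical helper, not
# part of the Mojo API.

def mma_reference(a, b, c, c_scale=1):
    """Return a @ b + c_scale * c for row-major nested lists."""
    if c_scale not in (0, 1):
        raise ValueError("c_scale must be 0 or 1")
    m, k, n = len(a), len(b), len(b[0])
    if any(len(row) != k for row in a):
        raise ValueError("inner dimensions of a and b must match")
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            # c_scale == 0 discards the old accumulator value.
            acc = c[i][j] if c_scale == 1 else 0
            for p in range(k):
                acc += a[i][p] * b[p][j]
            row.append(acc)
        out.append(row)
    return out


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
c = [[100, 100], [100, 100]]

accumulated = mma_reference(a, b, c, c_scale=1)  # product added to C
overwritten = mma_reference(a, b, c, c_scale=0)  # product replaces C
```

With c_scale=1 the old contents of C contribute to the result; with c_scale=0 they are ignored, which is useful when the accumulator has not been initialized.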

mma[kind: UMMAKind, //, cta_group: Int = 1, /, *, c_scale: SIMD[uint32, 1] = 1](a_desc: SIMD[uint32, 1], b_desc: MMASmemDescriptor, c_tmem: SIMD[uint32, 1], inst_desc: UMMAInsDescriptor[kind])

Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction.

Parameters:

  • kind (UMMAKind): Data type of the matrices.
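  • cta_group (Int): Number of CTAs (cooperative thread arrays) used by the MMA. Defaults to 1.
  • c_scale (SIMD[uint32, 1]): Scale factor applied to the existing contents of the C matrix before accumulation. Must be 0 (overwrite C with the product) or 1 (add the product to C). Defaults to 1.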

Args:

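  • a_desc (SIMD[uint32, 1]): The A matrix operand, passed as a 32-bit value rather than an MMASmemDescriptor (given the type, presumably an address in tensor memory).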
  • b_desc (MMASmemDescriptor): The descriptor for the B matrix.
  • c_tmem (SIMD[uint32, 1]): The address of the C matrix in tensor memory.
  • inst_desc (UMMAInsDescriptor[kind]): The descriptor for the MMA instruction.
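
One way an instruction like this might be used is to walk the K dimension in tiles, overwriting the accumulator on the first tile and accumulating on the rest. The sketch below models that pattern in plain Python under the assumption that each step computes C = A_tile × B_tile + c_scale × C. The helpers `mma_step` and `tiled_matmul` are hypothetical, and the accumulator is seeded with a stale value to show why the first step uses c_scale=0. In the Mojo signature c_scale is a compile-time parameter, so the runtime conditional here is only a modeling convenience.

```python
# Reference model only: `mma_step` and `tiled_matmul` are hypothetical
# helpers that mimic the assumed semantics C = A_tile @ B_tile + c_scale * C.

def mma_step(a_tile, b_tile, c, c_scale):
    """Update c in place with a_tile @ b_tile + c_scale * c."""
    for i in range(len(a_tile)):
        for j in range(len(b_tile[0])):
            acc = c[i][j] if c_scale == 1 else 0
            for p in range(len(b_tile)):
                acc += a_tile[i][p] * b_tile[p][j]
            c[i][j] = acc


def tiled_matmul(a, b, tile_k):
    """Multiply a (M x K) by b (K x N), consuming K in chunks of tile_k."""
    m, k, n = len(a), len(b), len(b[0])
    # Stale values stand in for an uninitialized accumulator.
    c = [[-999] * n for _ in range(m)]
    for k0 in range(0, k, tile_k):
        a_tile = [row[k0:k0 + tile_k] for row in a]
        b_tile = b[k0:k0 + tile_k]
        # First tile overwrites the accumulator; later tiles add to it.
        mma_step(a_tile, b_tile, c, c_scale=0 if k0 == 0 else 1)
    return c


a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b = [[1, 0], [0, 1], [1, 0], [0, 1]]
result = tiled_matmul(a, b, tile_k=2)
```

Because the first step discards the stale accumulator, the result matches a full matrix product without a separate zero-initialization pass.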