
Mojo function

mma_rdna

mma_rdna[c_register_buffer_type: RegisterBuffer, a_register_buffer_type: RegisterMMABuffer, b_buffer_type: KVBuffer, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BK: Int, prefetch_function: OptionalReg[fn() capturing -> None], swap_a_b: Bool = False, beg_iter: Int = 0, num_iters: Int = 1, prefetched_b_tile: Bool = False, a_copy_fn: fn[i: Int]() capturing -> None = _noop_copy_fn](c: c_register_buffer_type, mut a_tile: a_register_buffer_type, mut b_tile: b_buffer_type)

RDNA-specific MMA operation for Wave32 WMMA.

This function performs matrix multiply-accumulate operations using RDNA's 16x16x16 WMMA instructions. It handles the K-dimension tiling and manages shared memory staging for the B operand.
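The K-dimension tiling can be modeled with a minimal Python sketch. It is not the Mojo implementation. It assumes that a BK-wide block is consumed in consecutive 16-wide K steps and that each step performs one 16x16x16 multiply-accumulate into the same C tile. The helpers `wmma_16x16x16` and `mma_k_tiled` are hypothetical and operate on plain nested lists rather than register fragments.

```python
# Hypothetical model of K-tiled WMMA accumulation (not the Mojo code).
# Assumption: BK is a multiple of 16 and every step accumulates into one
# 16x16 C tile, the way a chain of 16x16x16 WMMA instructions would.

WMMA_M = WMMA_N = WMMA_K = 16  # RDNA WMMA tile shape (M x N x K)


def wmma_16x16x16(c, a, b, k0):
    """One WMMA step: c += a[:, k0:k0+16] @ b[k0:k0+16, :] (in place)."""
    for i in range(WMMA_M):
        for j in range(WMMA_N):
            acc = c[i][j]
            for k in range(k0, k0 + WMMA_K):
                acc += a[i][k] * b[k][j]
            c[i][j] = acc


def mma_k_tiled(c, a, b, bk):
    """Walk a BK-wide K block in 16-wide steps, accumulating into c."""
    assert bk % WMMA_K == 0, "BK is assumed to be a multiple of 16"
    for k0 in range(0, bk, WMMA_K):
        wmma_16x16x16(c, a, b, k0)
    return c
```

Because C is read and written on every step, the accumulator stays live across all BK/16 steps. This explains why `c` is passed as a register buffer rather than being re-zeroed per step.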

Parameters:

  • c_register_buffer_type (RegisterBuffer): Type for the C accumulator buffer (8-element fragments).
  • a_register_buffer_type (RegisterMMABuffer): Type for the A input buffer (16-element fragments).
  • b_buffer_type (KVBuffer): Type for the B input buffer, which is loaded from shared memory.
  • tensor_core_mma (TiledTensorCore): The TiledTensorCore configuration for RDNA.
  • BK (Int): Block size in the K dimension.
  • prefetch_function (OptionalReg): Optional function that prefetches the next tile.
  • swap_a_b (Bool): Whether to swap A and B operands.
  • beg_iter (Int): Starting iteration index.
  • num_iters (Int): Number of iterations over tiles.
  • prefetched_b_tile (Bool): Whether the B tile has already been prefetched.
  • a_copy_fn (fn[i: Int]() capturing -> None): Callback that copies chunk i of A (P) to shared memory.
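The 8-element and 16-element fragment sizes above follow from simple per-lane arithmetic. The sketch below assumes the RDNA3-style Wave32 WMMA layout, in which the A and B inputs are replicated across the two 16-lane halves of the wavefront (lanes 16..31 hold the same data as lanes 0..15) while the 16x16 accumulator is split evenly over all 32 lanes. Other RDNA generations may lay fragments out differently. The helper names are hypothetical.

```python
# Per-lane fragment sizes for a 16x16x16 WMMA on a Wave32 wavefront.
# Assumption: RDNA3-style layout, where A/B inputs are replicated across
# the two 16-lane halves and the C accumulator is not replicated.

WAVE_SIZE = 32
M = N = K = 16


def a_fragment_len(m=M, k=K, wave_size=WAVE_SIZE, replication=2):
    """Elements of the m x k A tile held by each lane.

    Each element is stored in `replication` lanes, so the wave holds
    m * k * replication values in total, spread over wave_size lanes.
    """
    return m * k * replication // wave_size


def c_fragment_len(m=M, n=N, wave_size=WAVE_SIZE):
    """Elements of the m x n accumulator tile held by each lane."""
    total = m * n
    assert total % wave_size == 0, "tile must divide evenly over the wave"
    return total // wave_size


print(a_fragment_len(), c_fragment_len())  # 16 8
```

Under these assumptions, A needs 16 elements per lane (one full K-length row) and C needs 8 (256 accumulator values over 32 lanes). These values match the fragment sizes given for `a_register_buffer_type` and `c_register_buffer_type`.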

Args:

  • c (c_register_buffer_type): Accumulator register buffer.
  • a_tile (a_register_buffer_type): A operand register buffer.
  • b_tile (b_buffer_type): B operand buffer (loaded into shared memory).
