Mojo function
mma_rdna
mma_rdna[c_register_buffer_type: RegisterBuffer, a_register_buffer_type: RegisterMMABuffer, b_buffer_type: KVBuffer, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BK: Int, prefetch_function: OptionalReg[fn() capturing -> None], swap_a_b: Bool = False, beg_iter: Int = 0, num_iters: Int = 1, prefetched_b_tile: Bool = False, a_copy_fn: fn[i: Int]() capturing -> None = _noop_copy_fn](c: c_register_buffer_type, mut a_tile: a_register_buffer_type, mut b_tile: b_buffer_type)
RDNA-specific MMA operation for Wave32 WMMA.
This function performs matrix multiply-accumulate operations using RDNA's 16x16x16 WMMA instructions. It handles the K-dimension tiling and manages shared memory staging for the B operand.
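The K-dimension tiling described above can be modeled on the CPU. The following Python sketch is a hypothetical reference model, not the Mojo implementation. `wmma_16x16x16` and `mma_block` are illustrative names, shared memory staging for B is omitted, and BK is assumed to be a multiple of the 16-wide WMMA K. The model issues one 16x16x16 multiply-accumulate per (M tile, N tile, K step).

```python
# Hypothetical CPU reference model (not the Mojo kernel): emulate
# C += A @ B for one BK-wide block using 16x16x16 tile MMAs along K.
TILE = 16  # assumed RDNA WMMA tile shape: M = N = K = 16


def wmma_16x16x16(c, a, b, m0, n0, k0):
    """Accumulate one 16x16x16 tile product into c in place.

    c is M x N, a is M x K, b is K x N (lists of lists). Only the
    16x16 tile of c at (m0, n0) and the K slice [k0, k0 + 16) are used.
    """
    for i in range(m0, m0 + TILE):
        for j in range(n0, n0 + TILE):
            acc = c[i][j]
            for k in range(k0, k0 + TILE):
                acc += a[i][k] * b[k][j]
            c[i][j] = acc


def mma_block(c, a, b, BK):
    """Split the K extent BK into BK // 16 WMMA steps; return the step count."""
    if BK % TILE != 0:
        raise ValueError("BK must be a multiple of the 16-wide WMMA K")
    M, N = len(c), len(c[0])
    steps = 0
    for k0 in range(0, BK, TILE):          # walk K in 16-wide slices
        for m0 in range(0, M, TILE):       # every 16-row tile of C
            for n0 in range(0, N, TILE):   # every 16-column tile of C
                wmma_16x16x16(c, a, b, m0, n0, k0)
                steps += 1
    return steps
```

For a 16x32 C block with BK = 48, `mma_block` issues 3 x 1 x 2 = 6 tile MMAs. Because each C element accumulates the partial sums from every K slice, the result equals the untiled product A @ B.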
Parameters:
- c_register_buffer_type (RegisterBuffer): Type for C accumulator buffer (8-element fragments).
- a_register_buffer_type (RegisterMMABuffer): Type for A input buffer (16-element fragments).
- b_buffer_type (KVBuffer): Type for B input buffer loaded from shared memory.
- tensor_core_mma (TiledTensorCore): The TiledTensorCore configuration for RDNA.
- BK (Int): Block size in K dimension.
- prefetch_function (OptionalReg): Optional function to prefetch next tile.
- swap_a_b (Bool): Whether to swap A and B operands.
- beg_iter (Int): Starting iteration index.
- num_iters (Int): Number of iterations over tiles.
- prefetched_b_tile (Bool): Whether B tile is already prefetched.
- a_copy_fn (fn[i: Int]() capturing -> None): Callback to copy A (P) chunk i to shared memory.
Args:
- c (c_register_buffer_type): Accumulator register buffer.
- a_tile (a_register_buffer_type): A operand register buffer.
- b_tile (b_buffer_type): B operand buffer (loaded to shared memory).
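This page does not say how `swap_a_b` rearranges `a_tile` and `b_tile` internally. One plausible reason to swap MMA operands, offered here only as an assumption, is the transpose identity (AB)^T = B^T A^T: accumulating B^T @ A^T produces C^T. The following Python sketch checks that identity. `mma` and `mma_swapped` are hypothetical helpers.

```python
def transpose(m):
    """Return the transpose of a list-of-lists matrix."""
    return [list(row) for row in zip(*m)]


def mma(c, a, b):
    """Accumulate c += a @ b in place and return c."""
    for i in range(len(a)):
        for j in range(len(b[0])):
            c[i][j] += sum(a[i][k] * b[k][j] for k in range(len(b)))
    return c


def mma_swapped(c_t, a, b):
    """Hypothetical swapped form: accumulate into C^T using B^T @ A^T."""
    return mma(c_t, transpose(b), transpose(a))
```

If `c` starts as an M x N zero matrix and `c_t` as an N x M zero matrix, then `transpose(mma_swapped(c_t, a, b))` equals `mma(c, a, b)`. Swapping the operands therefore only changes the orientation in which the accumulator is stored.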