Mojo module

buffers

K, V, Q, P, and Output buffers for RDNA Wave32 attention kernels.

Wave32 WMMA fragment geometry:

  • A/B per lane = 16 elements (full K; lanes 0-15 are unique, lanes 16-31 replicate them).
  • C/D per lane = 8 elements (M * N / WARP = 256 / 32).
  • Lane -> C/D map: lane l, elem v -> D[row = v * 2 + l // 16, col = l % 16].
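The lane-to-element arithmetic can be checked with a small plain-Python model. This is only an illustrative sketch, not part of the module's API: the helper name `cd_coord` and the constants `MMA_M` and `MMA_N` are hypothetical, the output tile is assumed to be 16 x 16 (consistent with the 256 / 32 figure), and the row formula is assumed to read `v * 2 + l // 16`.

```python
# Values mirrored from the module's comptime constants.
RDNA_AB_FRAG_SIZE = 16
RDNA_CD_FRAG_SIZE = 8
RDNA_WARP_SIZE = 32

# Assumption: one WMMA produces a 16 x 16 C/D tile (256 elements).
MMA_M = 16
MMA_N = 16


def cd_coord(lane, elem):
    """Map (lane l, per-lane element v) to (row, col) in the C/D tile.

    Hypothetical helper modelling: row = v * 2 + l // 16, col = l % 16.
    """
    return elem * 2 + lane // 16, lane % 16


# Enumerate every (lane, element) pair in the wave.
coords = [
    cd_coord(lane, elem)
    for lane in range(RDNA_WARP_SIZE)
    for elem in range(RDNA_CD_FRAG_SIZE)
]

# Each of the 256 tile positions should be owned by exactly one pair.
covers_tile_once = len(coords) == len(set(coords)) == MMA_M * MMA_N
```

Under these assumptions, lanes 0-15 own the even rows and lanes 16-31 own the odd rows, and each lane keeps a fixed column, so the 32 lanes x 8 elements tile the 16 x 16 output without overlap.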

comptime values​

RDNA_AB_FRAG_SIZE​

comptime RDNA_AB_FRAG_SIZE = 16

RDNA_CD_FRAG_SIZE​

comptime RDNA_CD_FRAG_SIZE = 8

RDNA_WARP_SIZE​

comptime RDNA_WARP_SIZE = 32

Structs​

  • ​KBufferRDNA: K buffer: holds a (BN, depth) DRAM tile reference, a register load_tile that staggers DMA across BK strips, an mma_tile for the current K fragment, and a (BN, BK) LDS region for the staged strip.
  • OutputRegisterBufferRDNA: Output accumulator register buffer. Layout is (num_n_mmas * num_m_mmas, RDNA_CD_FRAG_SIZE) row_major: one row per MMA tile and one column for each of a lane's C/D registers.
  • ​PRegisterBufferRDNA: P register buffer (post-softmax scores). Holds the accumulator in registers; copy_to_shared casts to dtype and writes to a [BK, BM] SMEM region that the PV phase reads back as A.
  • ​QRegisterBufferRDNA: Q register buffer: loads each warp's (WM, depth) Q sub-tile into BK-strip MMA fragments at construction.
  • ​VBufferRDNA: V buffer with transpose-on-LDS-write.