Mojo module
buffers
K, V, Q, P, and Output buffers for RDNA Wave32 attention kernels.
Wave32 WMMA fragment geometry:
- A/B per lane = 16 elements (the full K dimension; lanes 0-15 hold unique data, lanes 16-31 replicate it).
- C/D per lane = 8 elements (MN / WARP = 256 / 32).
- Lane -> C/D map: lane l, element v -> D[row = v*2 + l//16, col = l%16].
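A quick way to check the lane -> C/D map is to invert it and confirm that it covers each of the 256 elements of D exactly once. The sketch below is a plain Python model of that formula, not Mojo code from this module; cd_coord and owner are names made up for illustration. It also makes the row interleaving visible: lanes 0-15 own the even rows of D and lanes 16-31 own the odd rows.

```python
# Python model of the documented Wave32 C/D lane map (not Mojo code).
WARP_SIZE = 32     # mirrors RDNA_WARP_SIZE
CD_FRAG_SIZE = 8   # mirrors RDNA_CD_FRAG_SIZE
TILE = 16          # 16x16 C/D tile = 256 elements


def cd_coord(lane: int, v: int) -> tuple:
    """(row, col) of D held by `lane` in per-lane register slot `v`."""
    return v * 2 + lane // 16, lane % 16


# Invert the map and check every D element has exactly one owner.
owner = {}
for lane in range(WARP_SIZE):
    for v in range(CD_FRAG_SIZE):
        rc = cd_coord(lane, v)
        assert rc not in owner, f"{rc} is assigned twice"
        owner[rc] = (lane, v)

assert len(owner) == TILE * TILE
print(owner[(0, 0)], owner[(1, 0)], owner[(15, 15)])  # (0, 0) (16, 0) (31, 7)
```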
comptime values
RDNA_AB_FRAG_SIZE
comptime RDNA_AB_FRAG_SIZE = 16
RDNA_CD_FRAG_SIZE
comptime RDNA_CD_FRAG_SIZE = 8
RDNA_WARP_SIZE
comptime RDNA_WARP_SIZE = 32
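The three constants are tied together by the fragment geometry described at the top of the module. The Python model below (not Mojo) checks those relationships, assuming a 16x16x16 WMMA tile shape. ab_source_index is a hypothetical helper: the claim that lane l loads row l % 16 of A (or column l % 16 of B) is an assumption that matches the "lanes 0-15 unique, 16-31 replicate" description, not something this page states outright.

```python
# Python model of the Wave32 WMMA fragment sizes (not Mojo code).
RDNA_WARP_SIZE = 32
RDNA_AB_FRAG_SIZE = 16
RDNA_CD_FRAG_SIZE = 8
MMA_M = MMA_N = MMA_K = 16  # ASSUMPTION: 16x16x16 WMMA tile shape

# C/D: the M x N output tile is divided evenly across the warp.
assert RDNA_CD_FRAG_SIZE == MMA_M * MMA_N // RDNA_WARP_SIZE

# A/B: each lane holds one full K-length vector.
assert RDNA_AB_FRAG_SIZE == MMA_K

# Only half the lanes carry unique A/B data: 16 lanes x 16 elements.
unique_ab = (RDNA_WARP_SIZE // 2) * RDNA_AB_FRAG_SIZE
assert unique_ab == MMA_M * MMA_K


def ab_source_index(lane: int) -> int:
    """Row of A (or column of B) a lane loads; lanes 16-31 replicate 0-15."""
    return lane % 16


print([ab_source_index(l) for l in (0, 15, 16, 31)])  # [0, 15, 0, 15]
```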
Structs
- KBufferRDNA: K buffer: holds a (BN, depth) DRAM tile reference, a register load_tile that staggers DMA across BK strips, an mma_tile for the current K fragment, and a (BN, BK) LDS region for the staged strip.
- OutputRegisterBufferRDNA: Output accumulator register buffer. Layout is (num_n_mmas * num_m_mmas, RDNA_CD_FRAG_SIZE) row_major: one row per MMA tile, one column per C/D register held by each lane.
- PRegisterBufferRDNA: P register buffer (post-softmax scores). Holds the accumulator in registers; copy_to_shared casts it to dtype and writes it to a [BK, BM] SMEM region that the PV phase reads back as the A operand.
- QRegisterBufferRDNA: Q register buffer: loads each warp's (WM, depth) Q sub-tile into BK-strip MMA fragments at construction.
- VBufferRDNA: V buffer with transpose-on-LDS-write.
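KBufferRDNA stages its (BN, depth) K tile through LDS one (BN, BK) strip at a time. The Python model below (not Mojo) shows only that partitioning; k_strips is a hypothetical helper, and BN, DEPTH and BK are illustrative values rather than library defaults. How the real load_tile overlaps DMA for one strip with MMA work on another is not modeled.

```python
# Hypothetical Python model of KBufferRDNA's BK-strip staging (not Mojo code).
# BN, DEPTH and BK are illustrative values, not library defaults.
BN, DEPTH, BK = 64, 128, 16


def k_strips(depth: int, bk: int):
    """Yield (strip_index, col_start, col_end) for each BK-wide strip."""
    assert depth % bk == 0, "depth must be a multiple of BK"
    for s in range(depth // bk):
        yield s, s * bk, (s + 1) * bk


# Toy K tile whose values encode their (row, col) position.
k_tile = [[r * DEPTH + c for c in range(DEPTH)] for r in range(BN)]

for s, c0, c1 in k_strips(DEPTH, BK):
    # Only one (BN, BK) strip occupies the LDS region at a time.
    lds = [row[c0:c1] for row in k_tile]
    assert len(lds) == BN and len(lds[0]) == BK
    assert lds[0][0] == s * BK  # row 0 of strip s starts at column s*BK
```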
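OutputRegisterBufferRDNA's (num_n_mmas * num_m_mmas, RDNA_CD_FRAG_SIZE) row-major layout means that each MMA tile's 8 per-lane C/D registers sit contiguously. The Python sketch below computes a flat register offset. acc_offset is hypothetical, and the n-major tile numbering (tile = n * num_m_mmas + m) is an assumption; the page does not say which MMA index varies fastest.

```python
# Hypothetical Python model of the output accumulator layout (not Mojo code):
# (num_n_mmas * num_m_mmas, RDNA_CD_FRAG_SIZE), row-major.
RDNA_CD_FRAG_SIZE = 8


def acc_offset(m: int, n: int, v: int, num_m_mmas: int, num_n_mmas: int) -> int:
    """Flat per-lane register offset for MMA tile (m, n), C/D slot v.

    ASSUMPTION: tiles are numbered n-major (tile = n * num_m_mmas + m);
    the real ordering in OutputRegisterBufferRDNA may differ.
    """
    assert 0 <= m < num_m_mmas and 0 <= n < num_n_mmas
    assert 0 <= v < RDNA_CD_FRAG_SIZE
    tile = n * num_m_mmas + m            # one row per MMA tile
    return tile * RDNA_CD_FRAG_SIZE + v  # row-major: a tile's slots are contiguous


num_m, num_n = 2, 4
offsets = {acc_offset(m, n, v, num_m, num_n)
           for m in range(num_m) for n in range(num_n)
           for v in range(RDNA_CD_FRAG_SIZE)}
# Every register in the buffer is addressed exactly once.
assert offsets == set(range(num_m * num_n * RDNA_CD_FRAG_SIZE))
```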
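PRegisterBufferRDNA.copy_to_shared combines two steps: a dtype cast and a scatter from the per-lane C/D layout into a [BK, BM] SMEM region. The Python model below (not Mojo) uses IEEE half precision through the standard struct module as a stand-in for "dtype". It also assumes that P's rows index M (queries) and its columns index keys, so that element P[row][col] lands at smem[col][row]. Both the cast target and that orientation are assumptions, and copy_to_shared here is a hypothetical model rather than the module's method.

```python
# Hypothetical Python model of PRegisterBufferRDNA.copy_to_shared (not Mojo).
import struct

BM = BK = 16  # one 16x16 tile; sizes are illustrative assumptions


def to_fp16(x: float) -> float:
    """Round-trip through IEEE half precision to model the dtype cast."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def cd_coord(lane: int, v: int) -> tuple:
    """Documented Wave32 C/D map: (row, col) held by lane, slot v."""
    return v * 2 + lane // 16, lane % 16


def copy_to_shared(frags):
    """frags[lane][v] holds fp32 P values in the C/D layout.

    ASSUMPTION: P rows index M and columns index keys, so the [BK, BM]
    SMEM region is written transposed: smem[col][row].
    """
    smem = [[0.0] * BM for _ in range(BK)]
    for lane, regs in enumerate(frags):
        for v, x in enumerate(regs):
            row, col = cd_coord(lane, v)
            smem[col][row] = to_fp16(x)
    return smem


# Toy P tile with P[row][col] = row + col / 16, scattered across lanes.
frags = [[cd_coord(l, v)[0] + cd_coord(l, v)[1] / 16 for v in range(8)]
         for l in range(32)]
smem = copy_to_shared(frags)
assert smem[3][5] == 5 + 3 / 16  # P[5][3], exactly representable in fp16
```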
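QRegisterBufferRDNA splits a warp's (WM, depth) Q sub-tile into 16-row MMA tiles and BK-wide strips, each holding BK / 16 A-operand fragments. The Python model below (not Mojo) computes that grid and the 16 values one lane would hold. q_fragment_grid and q_lane_fragment are hypothetical helpers, the sizes are illustrative, and the lane-to-row assignment (lane % 16) is an assumption based on the A/B replication rule.

```python
# Hypothetical Python model of QRegisterBufferRDNA's fragment grid (not Mojo).
MMA_M = MMA_K = 16        # ASSUMPTION: 16x16x16 WMMA tiles
RDNA_AB_FRAG_SIZE = 16


def q_fragment_grid(wm: int, depth: int, bk: int) -> tuple:
    """(m tiles, BK strips, K fragments per strip) for a (wm, depth) Q tile."""
    assert wm % MMA_M == 0 and depth % bk == 0 and bk % MMA_K == 0
    return wm // MMA_M, depth // bk, bk // MMA_K


def q_lane_fragment(q, m_tile: int, strip: int, kf: int, lane: int, bk: int):
    """A-operand values one lane holds for fragment (m_tile, strip, kf)."""
    row = m_tile * MMA_M + lane % 16       # lanes 16-31 replicate lanes 0-15
    col0 = strip * bk + kf * MMA_K
    return q[row][col0:col0 + RDNA_AB_FRAG_SIZE]


WM, DEPTH, BK = 32, 128, 32               # illustrative sizes only
q = [[r * 1000 + c for c in range(DEPTH)] for r in range(WM)]
print(q_fragment_grid(WM, DEPTH, BK))     # (2, 4, 2)
```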
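One plausible reason to transpose V on its way into LDS: in O = P V, V is the B operand, and a B fragment lane holds one column of V across the key dimension. In a row-major V tile that column is strided, while after the transpose it becomes a single contiguous LDS row. The Python model below (not Mojo) illustrates this. write_v_transposed and b_lane_fragment are hypothetical helpers, and the column-per-lane B layout is an assumption about RDNA WMMA, not something stated on this page.

```python
# Hypothetical Python model of transpose-on-LDS-write (not Mojo code).
MMA_K = 16  # ASSUMPTION: 16x16x16 WMMA tiles


def write_v_transposed(v_tile):
    """Store a row-major (keys, depth) V tile as a (depth, keys) LDS buffer."""
    keys, depth = len(v_tile), len(v_tile[0])
    lds = [[0] * keys for _ in range(depth)]
    for k in range(keys):
        for d in range(depth):
            lds[d][k] = v_tile[k][d]  # the transpose happens on the write
    return lds


def b_lane_fragment(lds, n0: int, k0: int, lane: int):
    """B-operand values for one lane: column n0 + lane % 16 of V, 16 keys long."""
    return lds[n0 + lane % 16][k0:k0 + MMA_K]


v_tile = [[k * 100 + d for d in range(32)] for k in range(32)]  # toy V
lds = write_v_transposed(v_tile)
# The lane's K-run is now one contiguous LDS row instead of a strided column.
assert b_lane_fragment(lds, 16, 0, 3) == [v_tile[k][19] for k in range(16)]
```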