Mojo module
mma_warp
MMA warp logic for depth=512 pair-CTA SM100 attention.
Orchestrates the Q@K' and P@V matrix multiplications using pair-CTA SS MMA (cta_group=2). Both MMAs read both operands from SMEM, unlike FA4, which uses TS MMA for P@V.
Q@K' produces S in TMEM (double-buffered S_even/S_odd across iterations).
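The double buffering can be pictured as a parity choice on the iteration index. The following is a minimal Python sketch, not the module's code: the helper name `s_buffer_for` is hypothetical, and the mapping of even iterations to S_even is an assumption, since the page only names the two buffers.

```python
def s_buffer_for(iteration: int) -> str:
    """Pick the S buffer for a KV iteration.

    Assumption: even iterations write S_even and odd iterations write S_odd.
    The source only states that S is double-buffered across iterations.
    """
    if iteration < 0:
        raise ValueError("iteration must be non-negative")
    return "S_even" if iteration % 2 == 0 else "S_odd"


# Consecutive iterations alternate buffers, so one S tile can be consumed
# while the next Q@K' result is being produced into the other buffer.
schedule = [s_buffer_for(i) for i in range(4)]
```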
P@V is split into two MMA groups:
- P@V_lo (MMA_N=ov_depth/2): produces O_lo in TMEM cols [0, ov_depth/2)
- P@V_hi (MMA_N=ov_depth/2): produces O_hi in TMEM cols [ov_depth/2, ov_depth)
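The column split can be checked with a few lines of arithmetic. This is an illustrative Python sketch only: `o_column_ranges` is a hypothetical helper, and the value `ov_depth=512` in the usage line is an assumption taken from the module's depth=512 description.

```python
def o_column_ranges(ov_depth: int) -> dict[str, range]:
    """Return the half-open TMEM column range written by each P@V MMA group."""
    if ov_depth <= 0 or ov_depth % 2 != 0:
        raise ValueError("ov_depth must be a positive even number")
    half = ov_depth // 2  # MMA_N of each group
    return {
        "O_lo": range(0, half),         # cols [0, ov_depth/2)
        "O_hi": range(half, ov_depth),  # cols [ov_depth/2, ov_depth)
    }


# Assumed example value; the two ranges tile the full output width with no overlap.
ranges = o_column_ranges(512)
```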
The split enables pipelining: after V_lo completes, O_mma_lo fires and the correction warp can start rescaling O_lo while V_hi sub-stages are still running. V_lo and V_hi occupy separate KV pipeline slots, each holding a [BK1, ov_depth/4] half-tile per CTA (with cta_group=2, each CTA contributes ov_depth/4 columns to the B operand).
K is sub-staged into num_qk_stages=4 depth chunks (BK0 each). V is sub-staged into num_pv_stages=2 BN chunks (BK1 each), with V_lo and V_hi in separate slots (4 total V slots per iteration). Both use the fused KV pipeline where K and V sub-tiles share buffer slots.
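The sub-tile bookkeeping above can be enumerated as follows. This is a rough Python sketch under stated assumptions: the helper `kv_subtiles`, the sub-tile names, and the example values `bn=128` and `ov_depth=512` are all hypothetical, and the listing order is for illustration only rather than the order in which the module issues work.

```python
def kv_subtiles(depth, ov_depth, bn, num_qk_stages=4, num_pv_stages=2):
    """List the K and V sub-tiles that pass through the fused KV pipeline
    in one iteration, as (name, size) pairs.

    K entries carry the depth chunk BK0; V entries carry the per-CTA
    half-tile shape [BK1, ov_depth/4].
    """
    bk0 = depth // num_qk_stages    # depth chunk per K sub-stage
    bk1 = bn // num_pv_stages       # BN chunk per V sub-stage
    v_cols_per_cta = ov_depth // 4  # each CTA supplies half of an ov_depth/2 group
    tiles = [(f"K{i}", bk0) for i in range(num_qk_stages)]
    for half in ("V_lo", "V_hi"):   # lo and hi occupy separate slots
        for j in range(num_pv_stages):
            tiles.append((f"{half}{j}", (bk1, v_cols_per_cta)))
    return tiles


# Assumed example sizes: 4 K sub-tiles followed by 4 V slots.
tiles = kv_subtiles(depth=512, ov_depth=512, bn=128)
```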
CTA role split (cta_group=2):
- Leader CTA (even rank): Owns all pipeline interactions: waits on KV producer barriers, issues MMA, releases KV consumer barriers with cta_group=2 commit (fences MMA read of both CTAs' SMEM, then signals both CTAs' consumer barriers), and commits S/O barriers.
- Peer CTA (odd rank): Returns immediately.
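The role split amounts to an early return on odd ranks. The sketch below models it in Python purely for illustration: `is_leader` and `mma_warp_steps` are hypothetical names, and the step strings paraphrase the description above rather than any real API.

```python
def is_leader(cta_rank: int) -> bool:
    """In a CTA pair, the even-ranked CTA is the leader."""
    return cta_rank % 2 == 0


def mma_warp_steps(cta_rank: int) -> list[str]:
    """Return the pipeline actions the MMA warp performs on this CTA."""
    if not is_leader(cta_rank):
        return []  # peer CTA: returns immediately
    return [
        "wait on KV producer barrier",
        "issue MMA (cta_group=2)",
        "commit: fence MMA reads of both CTAs' SMEM, signal both consumer barriers",
        "commit S/O barriers",
    ]


leader_steps = mma_warp_steps(0)
peer_steps = mma_warp_steps(1)
```

Only the leader touches barriers, so the peer's SMEM slots are released by the leader's cta_group=2 commit rather than by any action of the peer itself.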
Functions