Mojo module
mma_warp
MMA warp logic for depth=256/512 pair-CTA SM100 attention.
Orchestrates the Q@K' and P@V matrix multiplications using pair-CTA SS MMA (cta_group=2). Both MMAs read both operands from SMEM, unlike FA4, which uses TS MMA for P@V.
Q@K' produces S in TMEM (double-buffered S_even/S_odd across iterations).
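The double buffering can be pictured as a parity toggle on the KV iteration index. The following is a minimal Python sketch of that idea, not the module's code. The function name `s_buffer_for_iteration` is hypothetical, and mapping iteration 0 to S_even is an assumption made for illustration.

```python
def s_buffer_for_iteration(kv_iter: int) -> str:
    """Pick the S buffer that Q@K' writes on a given KV iteration.

    Assumption: even iterations write S_even, odd iterations write S_odd.
    """
    return "S_even" if kv_iter % 2 == 0 else "S_odd"


# Consecutive iterations alternate buffers, so one S tile can be consumed
# while the next Q@K' fills the other buffer.
schedule = [s_buffer_for_iteration(i) for i in range(4)]
```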
Depth-dependent P@V strategy:
- split_o=True (depth=512): P@V is split into P@V_lo and P@V_hi, each with MMA_N=ov_depth/2. Produces O_lo and O_hi in separate TMEM regions.
- split_o=False (depth=256): A single P@V with MMA_N=ov_depth. Produces a single O in TMEM.
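As a rough illustration of the two cases, the sketch below lists the P@V MMAs issued per iteration and the MMA_N of each. It is plain Python that models only the arithmetic stated above. The helper name `pv_mma_plan` is hypothetical and is not part of the module.

```python
def pv_mma_plan(ov_depth: int, split_o: bool) -> list[tuple[str, int]]:
    """Return (output region, MMA_N) pairs for the P@V step of one iteration."""
    if split_o:
        half = ov_depth // 2  # each half MMA covers ov_depth/2 columns of O
        return [("O_lo", half), ("O_hi", half)]
    return [("O", ov_depth)]  # one MMA covers the full output depth


plan_512 = pv_mma_plan(512, split_o=True)
plan_256 = pv_mma_plan(256, split_o=False)
```

Under this reading, both configurations issue P@V MMAs with MMA_N=256. The depth=512 case simply issues two of them into separate TMEM regions.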
K is sub-staged into num_qk_stages depth chunks (BK0 each). V is sub-staged into num_pv_stages BN chunks (BK1 each). When split_o, V_lo and V_hi are in separate slots (4 total V slots per iteration); otherwise only 2 V slots per iteration.
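The stage and slot counts can be worked through with a small Python sketch. It assumes that num_qk_stages = depth / BK0 and num_pv_stages = BN / BK1, which is one reading of the sentence above rather than a confirmed formula. The function name `staging_counts` and the tile sizes in the example are hypothetical, chosen so that num_pv_stages comes out as 2 to match the slot counts in the text.

```python
def staging_counts(depth: int, bk0: int, bn: int, bk1: int, split_o: bool) -> dict[str, int]:
    """Count K sub-stages and V slots per iteration (assumed formulas)."""
    num_qk_stages = depth // bk0   # K is chunked along depth, BK0 per chunk
    num_pv_stages = bn // bk1      # V is chunked along BN, BK1 per chunk
    halves = 2 if split_o else 1   # V_lo and V_hi occupy separate slots when split
    return {
        "num_qk_stages": num_qk_stages,
        "num_pv_stages": num_pv_stages,
        "v_slots_per_iter": num_pv_stages * halves,
    }


# Illustrative tile sizes only; the real values are not given in this page.
split = staging_counts(depth=512, bk0=128, bn=128, bk1=64, split_o=True)
single = staging_counts(depth=256, bk0=128, bn=128, bk1=64, split_o=False)
```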
CTA role split (cta_group=2):
- Leader CTA (even rank): Owns all pipeline interactions. It waits on KV producer barriers, issues MMA, releases KV consumer barriers with a cta_group=2 commit (which fences the MMA read of both CTAs' SMEM, then signals both CTAs' consumer barriers), and commits S/O barriers.
- Peer CTA (odd rank): Returns immediately.
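The role split amounts to an early return on the odd-ranked CTA. The Python sketch below models only that control flow, with the leader's steps listed in the order the text gives them. The function name `mma_warp_steps` and the step strings are hypothetical labels, not calls into the module.

```python
def mma_warp_steps(cta_rank: int) -> list[str]:
    """List the pipeline steps the MMA warp performs on a CTA of a given rank."""
    if cta_rank % 2 != 0:
        return []  # peer CTA (odd rank) returns immediately
    # Leader CTA (even rank) drives the pipeline on behalf of the pair.
    return [
        "wait on KV producer barriers",
        "issue MMA",
        "release KV consumer barriers (cta_group=2 commit)",
        "commit S/O barriers",
    ]


leader_steps = mma_warp_steps(0)
peer_steps = mma_warp_steps(1)
```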
Functions