Mojo module

softmax_warp

Softmax warp group logic for depth=512 pair-CTA SM100 attention.

Computes online softmax over the Q@K' scores (S, held in TMEM) and writes the exponentiated result P to SMEM, where the SS MMA for P@V consumes it. Unlike FA4, where P lives in TMEM for a TS MMA, this kernel must explicitly transfer P from registers to SMEM with the correct swizzle layout.
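As a rough illustration of the online softmax recurrence mentioned above, here is a host-side Python sketch. It is not the kernel's code: the function name `online_softmax_row` and the `block` parameter are made up for the example, and it handles a single row of scores in plain Python rather than TMEM/SMEM tiles. It only shows how a running maximum and a rescaled running sum stay consistent as score blocks arrive.

```python
import math

NEG_INF = float("-inf")


def online_softmax_row(scores, block=4):
    """Return (row_max, row_sum) for one row, visiting `block` scores at a time.

    row_sum is the sum of exp(score - row_max) over the whole row, which is
    the softmax denominator relative to the final maximum.
    """
    row_max = NEG_INF
    row_sum = 0.0
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        new_max = max(row_max, max(chunk))
        # Rescale the sum accumulated so far from the old max to the new max.
        # On the first block there is nothing to rescale.
        scale = 0.0 if row_max == NEG_INF else math.exp(row_max - new_max)
        row_sum = row_sum * scale + sum(math.exp(x - new_max) for x in chunk)
        row_max = new_max
    return row_max, row_sum
```

Dividing each `exp(score - row_max)` by `row_sum` gives the usual softmax probabilities. In attention kernels this normalization is typically deferred until after the P@V accumulation; whether this module does so is not stated here.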

Pair-CTA TMEM column layout (cta_group=2), for MMA output [BM=64, MMA_N]:

Columns 0 : MMA_N//2 → TMEM rows 0-63 (warps 0-1)
Columns MMA_N//2 : MMA_N → TMEM rows 64-127 (warps 2-3)
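The column split above can be written as a small index helper. The Python sketch below is illustrative only: `owner_thread` is a hypothetical name, and it assumes 32 threads per warp and that the thread index equals the TMEM row it reads, as the layout description suggests.

```python
def owner_thread(m, col, mma_n, bm=64):
    """Map (M row, MMA output column) to (thread index, warp index).

    Assumes the first half of the columns sits in TMEM rows 0..bm-1 and the
    second half in rows bm..2*bm-1, with one thread per TMEM row.
    """
    if not 0 <= m < bm:
        raise ValueError("m out of range")
    if not 0 <= col < mma_n:
        raise ValueError("col out of range")
    half = mma_n // 2
    thread = m if col < half else m + bm
    return thread, thread // 32  # 32 threads per warp (assumption)
```

For example, with `mma_n=128`, row 5 and column 0 map to thread 5 in warp 0, while row 5 and column 64 map to thread 69 in warp 2.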

Each M row m is served by a thread pair: thread m (rows 0-63, first
half columns) and thread m+64 (rows 64-127, second half columns).
Full-row row_max and row_sum require cross-thread exchange via the
correction_smem buffer (64 Float32 slots).

Exchange pattern (2 named_barrier syncs per exchange):

1. Lower half writes partial value → sync
2. Upper half reads, computes combined, writes back → sync
3. Lower half reads combined value
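The three steps can be simulated on the host to check that two syncs are enough. The Python sketch below is an analogy, not the kernel: it uses 128 Python threads, `threading.Barrier` in place of named_barrier, and a 64-element list in place of correction_smem. The function name `exchange` and the `combine` parameter are hypothetical.

```python
import threading

BM = 64  # rows per tile; threads m and m + BM share row m


def exchange(partials, combine=max):
    """Combine partials[m] and partials[m + BM] so both threads see the result."""
    smem = [0.0] * BM                      # stand-in for the 64-slot buffer
    barrier = threading.Barrier(2 * BM)    # stand-in for named_barrier
    result = [None] * (2 * BM)

    def worker(t):
        m = t % BM
        lower = t < BM
        if lower:
            smem[m] = partials[t]                       # step 1: write partial
        barrier.wait()                                  # sync 1
        if not lower:
            result[t] = combine(smem[m], partials[t])   # step 2: read, combine
            smem[m] = result[t]                         #         write back
        barrier.wait()                                  # sync 2
        if lower:
            result[t] = smem[m]                         # step 3: read combined

    threads = [threading.Thread(target=worker, args=(t,)) for t in range(2 * BM)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return result
```

With `combine=max` this models the row_max exchange. A row_sum exchange would combine by addition, for example `combine=lambda a, b: a + b`, assuming both partial sums are already expressed relative to the same row maximum.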

Functions