Mojo module

softmax_warp

Softmax warp group logic for depth=256/512 pair-CTA SM100 attention.

Computes online softmax over the Q@K' scores (S, held in TMEM) and writes the exponentiated result P to SMEM, where the SS MMA P@V consumes it. Unlike FA4, where P lives in TMEM for a TS MMA, this kernel must explicitly transfer P from registers to SMEM with the correct swizzle layout.
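As a rough illustration of what "online softmax" means here, the following plain-Python reference model processes one score row tile by tile while keeping a running row_max and row_sum. It is only a sketch of the arithmetic: the function name `online_softmax_row` and the `tile` parameter are hypothetical, it assumes finite (unmasked) scores, and it says nothing about TMEM, SMEM, or swizzling.

```python
import math

def online_softmax_row(scores, tile=4):
    """Softmax of one score row, visiting the columns tile by tile.

    Hypothetical reference model; assumes every score is finite.
    """
    row_max = float("-inf")
    row_sum = 0.0
    p = []  # unnormalized exponentials, relative to the max seen so far
    for start in range(0, len(scores), tile):
        block = scores[start:start + tile]
        new_max = max(row_max, max(block))
        # Rescale everything accumulated under the old max.
        # On the first tile row_max is -inf, so the correction is exp(-inf) = 0.
        correction = math.exp(row_max - new_max)
        row_sum *= correction
        p = [x * correction for x in p]
        # Exponentiate the new tile against the updated max.
        exps = [math.exp(s - new_max) for s in block]
        row_sum += sum(exps)
        p.extend(exps)
        row_max = new_max
    return [x / row_sum for x in p], row_max, row_sum
```

A real attention kernel would typically rescale the output accumulator rather than keep every earlier P value around; the list `p` is retained here only so the result can be compared against a direct softmax.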

Pair-CTA TMEM column layout (cta_group=2):

For MMA output [BM, MMA_N]:

Columns 0 : MMA_N//2 → TMEM rows 0..BM-1

Columns MMA_N//2 : MMA_N → TMEM rows BM..MMA_M-1
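The layout rule above can be written as a small index helper. This is a sketch under stated assumptions: `tmem_row_for` is a hypothetical name, it assumes MMA_M == 2 * BM (true for both configurations listed on this page), and it only models which TMEM row an output element lands in, not the TMEM column address.

```python
def tmem_row_for(m, n, BM, MMA_N):
    """TMEM row holding element (m, n) of the [BM, MMA_N] MMA output.

    Hypothetical helper; assumes MMA_M == 2 * BM.
    """
    if not (0 <= m < BM and 0 <= n < MMA_N):
        raise IndexError("(m, n) outside the [BM, MMA_N] tile")
    half = MMA_N // 2
    # Left column half stays in rows 0..BM-1; right half is shifted down by BM.
    return m if n < half else m + BM
```

For example, with BM=64 and an illustrative MMA_N=128 (a value not taken from this page), element (5, 0) maps to TMEM row 5 and element (5, 64) maps to TMEM row 69.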

Depth-dependent behavior:

split_o=True (d512, MMA_M=128, BM=64): Each M row is served by a thread pair (row m and row m+64). exchange_reduce combines cross-thread row_max/row_sum.

split_o=False (d256, MMA_M=256, BM=128): Each thread covers a unique M-row with full BN columns. No exchange needed.
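In the split_o=True case, each thread of a pair sees only part of the columns of its M row, so the two partial (row_max, row_sum) states have to be merged. The sketch below shows the standard way to combine two partial softmax states; it is an assumption that exchange_reduce performs this arithmetic, since its signature and implementation are not shown on this page. The names `partial_state` and `merge_partial_softmax` are hypothetical.

```python
import math

def partial_state(values):
    """(row_max, row_sum) for the columns one thread sees; values must be finite."""
    row_max = max(values)
    row_sum = sum(math.exp(v - row_max) for v in values)
    return row_max, row_sum

def merge_partial_softmax(max_a, sum_a, max_b, sum_b):
    """Combine two partial softmax states that belong to the same M row."""
    row_max = max(max_a, max_b)
    # Each partial sum is re-expressed relative to the shared max before adding.
    row_sum = (sum_a * math.exp(max_a - row_max)
               + sum_b * math.exp(max_b - row_max))
    return row_max, row_sum
```

Merging the states of the two column halves gives the same row_max and row_sum as a single pass over the full row, which is why the split_o=False path, where one thread already covers all BN columns, needs no exchange.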

Functions
