Mojo module
softmax_warp
Softmax warp group logic for depth=256/512 pair-CTA SM100 attention.
Computes online softmax over Q@K' scores (S in TMEM) and writes the exponentiated result P to SMEM for SS MMA P@V consumption. Unlike FA4, where P lives in TMEM for TS MMA, this kernel must explicitly transfer P from registers to SMEM with the correct swizzle layout.
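The online softmax recurrence referred to above can be sketched in plain Python for a single query row. This is only an illustration of the math, not the kernel's code: the function names, the block size, and the list-based data layout are assumptions made for the example, and scores are assumed to be finite. The names row_max, row_sum, and inv_row_sum mirror the terms used on this page.

```python
import math


def online_softmax_attention_row(score_blocks, value_blocks):
    """Hypothetical helper: attention output for one query row, block by block."""
    row_max = float("-inf")  # running maximum of all scores seen so far
    row_sum = 0.0            # running sum of exp(score - row_max)
    acc = [0.0] * len(value_blocks[0][0])  # unnormalized output O
    for scores, values in zip(score_blocks, value_blocks):
        new_max = max(row_max, max(scores))
        # Rescale earlier partial results to the new running maximum.
        correction = math.exp(row_max - new_max)
        # Unnormalized P block (the values handed to the P@V product).
        p = [math.exp(s - new_max) for s in scores]
        row_sum = row_sum * correction + sum(p)
        acc = [a * correction for a in acc]
        for p_j, v_j in zip(p, values):
            acc = [a + p_j * v for a, v in zip(acc, v_j)]
        row_max = new_max
    # Normalization is deferred to the end: scale O by inv_row_sum once.
    inv_row_sum = 1.0 / row_sum
    return [a * inv_row_sum for a in acc]


def reference_attention_row(scores, values):
    """Plain softmax(scores) @ values, used only to check the sketch."""
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    total = sum(e)
    depth = len(values[0])
    return [sum(e_j * v[k] for e_j, v in zip(e, values)) / total
            for k in range(depth)]
```

Splitting the scores into blocks of any size should give the same result as the one-pass reference, up to floating-point rounding, which is what allows P to be written out unnormalized and O to be scaled once at the end.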
Pair-CTA TMEM column layout (cta_group=2): For MMA output [BM, MMA_N]:
- Columns 0 : MMA_N//2 → TMEM rows 0..BM-1
- Columns MMA_N//2 : MMA_N → TMEM rows BM..MMA_M-1
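The column-to-row mapping above can be restated as a small lookup. The sketch below is illustrative Python, not part of the module: tmem_row_range is a hypothetical helper, and the MMA_N value used in the usage note is an assumed example, since this page does not state it.

```python
def tmem_row_range(col, bm, mma_m, mma_n):
    """Hypothetical helper: inclusive TMEM row range holding MMA output column `col`."""
    if not 0 <= col < mma_n:
        raise ValueError("column index out of range")
    if col < mma_n // 2:
        # First half of the columns lands in TMEM rows 0..BM-1.
        return (0, bm - 1)
    # Second half of the columns lands in TMEM rows BM..MMA_M-1.
    return (bm, mma_m - 1)
```

With the d512 parameters (BM=64, MMA_M=128) and an assumed MMA_N of 128, column 0 maps to rows 0..63 and column 64 maps to rows 64..127.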
Depth-dependent behavior:
- split_o=True (d512, MMA_M=128, BM=64): Each M row is served by a thread pair (row m and row m+64). exchange_reduce combines cross-thread row_max/row_sum.
- split_o=False (d256, MMA_M=256, BM=128): Each thread covers a unique M-row with full BN columns. No exchange needed.
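For the split_o=True case, the two threads of a pair each see only part of a row's columns, so their partial statistics have to be merged. The following Python sketch shows the standard way to combine two (row_max, row_sum) pairs. It is an assumption about the math involved, not the implementation of exchange_reduce, which is not shown on this page and may, for example, exchange the maxima first and the sums afterwards. Both helper names are hypothetical.

```python
import math


def partial_stats(scores):
    """Hypothetical helper: (row_max, row_sum) over one thread's share of a row."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)


def combine_stats(a, b):
    """Merge two (row_max, row_sum) pairs computed over disjoint column sets."""
    m = max(a[0], b[0])
    # Each partial sum is rescaled from its own maximum to the shared one.
    s = a[1] * math.exp(a[0] - m) + b[1] * math.exp(b[0] - m)
    return m, s
```

Merging the statistics of the two halves of a row should match the statistics computed over the whole row, up to floating-point rounding. In the split_o=False case each thread already holds the full row, so no such merge is required.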
Functions
- depth512_scale_write_output: Read O from TMEM, scale by inv_row_sum, write to SMEM, TMA store.
- depth512_softmax