IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

softmax_warp

Softmax warp group logic for depth=256/512 pair-CTA SM100 attention.

Computes online softmax over the Q@K' scores (S, held in TMEM) and writes the exponentiated result P to SMEM, where an SS MMA (both operands in SMEM) consumes it to compute P@V. Unlike FA4, where P stays in TMEM and feeds a TS MMA, this kernel must explicitly move P from registers to SMEM in the correct swizzle layout.
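As an illustration of the online-softmax step, here is a minimal Python model of one query row processed one BN-wide score tile at a time. It is a sketch of the standard online-softmax recurrence, not the Mojo kernel. `online_softmax_row` and the plain-list tiles are hypothetical, and the sketch ignores TMEM/SMEM placement and swizzling entirely. Each `p_tile` stands in for the exponentiated P that the kernel stages in SMEM before the P@V MMA. The sketch also accumulates O so that the rescaling by `exp(old_max - new_max)` is visible. It does not model how the kernel divides that work across warp groups.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def online_softmax_row(
    score_tiles: List[List[float]], value_tiles: List[Matrix]
) -> Tuple[List[float], float, float]:
    """Hypothetical CPU model of online softmax for ONE query row.

    score_tiles[t][j]:    S = Q@K' score for key j of tile t
    value_tiles[t][j][d]: row j of the V tile matching score tile t
    Returns (softmax(S) @ V for this row, final row_max, final row_sum).
    """
    depth = len(value_tiles[0][0])
    row_max = -math.inf
    row_sum = 0.0
    acc = [0.0] * depth  # unnormalized O for this row

    for s_tile, v_tile in zip(score_tiles, value_tiles):
        new_max = max(row_max, max(s_tile))
        # Rescale everything accumulated under the old max.
        # On the first tile row_max is -inf, so scale == 0.0.
        scale = math.exp(row_max - new_max)
        row_sum *= scale
        acc = [a * scale for a in acc]

        # P = exp(S - running max): the values the kernel stages in SMEM.
        p_tile = [math.exp(s - new_max) for s in s_tile]
        row_sum += sum(p_tile)

        # O += P @ V for this tile.
        for p, v_row in zip(p_tile, v_tile):
            for d in range(depth):
                acc[d] += p * v_row[d]
        row_max = new_max

    # Final normalization by the softmax denominator.
    return [a / row_sum for a in acc], row_max, row_sum


if __name__ == "__main__":
    scores = [[1.0, 2.0], [0.5, 3.0]]  # two tiles with BN = 2
    values = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [2.0, 0.0]]]
    out, m, l = online_softmax_row(scores, values)
    print(out, m, l)
```

The `__main__` block splits a toy row of four scores into two BN=2 tiles. Its output matches a direct softmax over all four scores followed by a weighted sum of the V rows.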

Pair-CTA TMEM column layout (cta_group=2):
For MMA output [BM, MMA_N]:
  Columns 0 : MMA_N//2      → TMEM rows 0..BM-1
  Columns MMA_N//2 : MMA_N  → TMEM rows BM..MMA_M-1
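The column split can be written as a coordinate mapping. The Python sketch below is a hypothetical model of the layout note, not a kernel API. It assumes MMA_M == 2 * BM, which holds for both listed configurations. It also assumes that each column half starts at TMEM column 0, which the note above does not state. `tmem_coord` is an illustrative name, and MMA_N = 128 is an arbitrary example value.

```python
from typing import Tuple


def tmem_coord(m: int, n: int, bm: int, mma_m: int, mma_n: int) -> Tuple[int, int]:
    """Hypothetical model of the pair-CTA split.

    Maps logical (m, n) of a [BM, MMA_N] MMA output to (TMEM row, TMEM column).
    Assumption: each column half starts at TMEM column 0.
    """
    if mma_m != 2 * bm:
        raise ValueError("layout note implies MMA_M == 2 * BM")
    if not (0 <= m < bm and 0 <= n < mma_n):
        raise ValueError("(m, n) outside the [BM, MMA_N] output")
    half = mma_n // 2
    if n < half:
        return m, n           # columns 0 : MMA_N//2 -> rows 0..BM-1
    return bm + m, n - half   # columns MMA_N//2 : MMA_N -> rows BM..MMA_M-1


if __name__ == "__main__":
    BM, MMA_M, MMA_N = 64, 128, 128  # d512 shape; MMA_N = 128 is illustrative
    print(tmem_coord(5, 3, BM, MMA_M, MMA_N))   # (5, 3)
    print(tmem_coord(5, 70, BM, MMA_M, MMA_N))  # (69, 6)
```

Under this model, with BM=64 one logical row m is spread across TMEM rows m and m+64. Those are the two threads that the d512 path pairs up.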

Depth-dependent behavior:
  split_o=True (d512, MMA_M=128, BM=64): each M row is served by a thread pair (row m and row m+64). exchange_reduce combines the two threads' partial row_max/row_sum.
  split_o=False (d256, MMA_M=256, BM=128): each thread covers a unique M row across the full BN columns, so no exchange is needed.
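The split_o=True exchange merges two partial softmax statistics for the same row. You take the larger max and rescale each partial sum to it before adding. The Python sketch below models that arithmetic only. `partial_stats` and `combine_row_stats` are hypothetical names, not the signature of `exchange_reduce`, and the sketch does not model how the two threads actually exchange values.

```python
import math
from typing import List, Tuple

RowStats = Tuple[float, float]  # (row_max, sum of exp(s - row_max))


def partial_stats(scores: List[float]) -> RowStats:
    """Row max and exp-sum over the columns one thread holds."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)


def combine_row_stats(a: RowStats, b: RowStats) -> RowStats:
    """Merge two threads' partial (row_max, row_sum) for the same M row.

    Each partial sum was computed relative to its own max, so both are
    rescaled to the shared max before being added.
    """
    m = max(a[0], b[0])
    return m, a[1] * math.exp(a[0] - m) + b[1] * math.exp(b[0] - m)


if __name__ == "__main__":
    row = [0.3, -1.2, 2.5, 0.0, 1.1, 4.0, -0.7, 0.9]  # one M row, BN = 8
    half = len(row) // 2
    # split_o=True: one thread holds the first column half, its partner the second.
    first = partial_stats(row[:half])
    second = partial_stats(row[half:])
    print(combine_row_stats(first, second))
    # split_o=False: one thread sees the whole row, so no exchange is needed.
    print(partial_stats(row))
```

Both printed pairs agree up to floating-point rounding. This shows why the d256 path can skip the exchange while the d512 path cannot.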

Functions