
Mojo module

mha_softmax

Online softmax row-state bundle for MhaPrefillV2.

OnlineSoftmax (below) owns the four row-state scalars maintained by the FlashAttention-2 online-softmax recurrence (max_vec, max_vec_prev, norm_vec, scale_vec) as Float32 fields. It exposes the recurrence steps, the cross-lane reductions, and the column-broadcast operations as static methods. Cluster functions in mha_prefill_v2.mojo thread a single OnlineSoftmax reference instead of four RegTile parameters.
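For readers unfamiliar with the recurrence, the following is a minimal scalar sketch in Python of how the four pieces of state could evolve for a single Q row. It is a reference model only, not the Mojo implementation: the class name `OnlineSoftmaxRow`, the `update` method, the helper `attend_row`, and the exact ordering of the steps are assumptions made for illustration. Only the four field names come from this page. The sketch also assumes every score block contains at least one finite value.

```python
import math


class OnlineSoftmaxRow:
    """Hypothetical scalar model of one row's online-softmax state."""

    def __init__(self):
        self.max_vec = float("-inf")       # running row max
        self.max_vec_prev = float("-inf")  # row max before the latest block
        self.norm_vec = 0.0                # running sum of exp(score - max)
        self.scale_vec = 1.0               # pending rescale for older partial sums

    def update(self, scores):
        """Fold one block of scores into the state; return unnormalized probs."""
        self.max_vec_prev = self.max_vec
        self.max_vec = max(self.max_vec, max(scores))
        # exp(-inf) is 0.0, so the first block simply discards the empty history.
        self.scale_vec = math.exp(self.max_vec_prev - self.max_vec)
        probs = [math.exp(s - self.max_vec) for s in scores]
        self.norm_vec = self.norm_vec * self.scale_vec + sum(probs)
        return probs


def attend_row(score_blocks, value_blocks):
    """Softmax-weighted sum of scalar values, processed block by block."""
    state = OnlineSoftmaxRow()
    acc = 0.0
    for scores, values in zip(score_blocks, value_blocks):
        probs = state.update(scores)
        # The pending scale rebases the old accumulator onto the new max.
        acc = acc * state.scale_vec + sum(p * v for p, v in zip(probs, values))
    return acc / state.norm_vec
```

Processing the scores in blocks this way should agree, up to floating-point rounding, with a softmax computed over the whole row at once, which is what allows the kernel to keep only four scalars per row rather than the full score row.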

The col_l rt_32x32 accumulator topology gives each lane ownership of one column of the 32x32 fragment, which corresponds to one Q row in the warp's stripe. The online-softmax recurrence therefore tracks one running max, one running norm, and one pending scale per lane. Each of the four pieces of state is a single FP32 scalar in a VGPR.

Each column of an rt_32x32 is held redundantly across two half-warps (lanes [0, 32) and [32, 64)). Per-column reductions combine in-lane via SIMD.reduce_* and then across half-warps via permlane_swap[32], a single-cycle DPP-style swap. Using stdlib's lane_group_reduce here would lower to ds_bpermute_b32 (LDS-routed), so we go through permlane_swap directly.
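The two-stage reduction described above can be modeled on the host with a short Python sketch. This is a simulation under stated assumptions, not GPU code: it assumes lane `l` and lane `l + 32` hold elements of the same column and that the swap pairs exactly those two lanes. The function name `column_max` and the constants are hypothetical.

```python
HALF_WARP = 32
WARP_SIZE = 64


def column_max(lane_values):
    """Simulate a per-column max over a 64-lane warp.

    lane_values[l] is the list of elements lane l holds in registers.
    Assumption: lanes l and l + 32 own elements of the same column.
    """
    assert len(lane_values) == WARP_SIZE
    # Stage 1: in-lane reduction (stands in for SIMD.reduce_max).
    partial = [max(vals) for vals in lane_values]
    # Stage 2: exchange partials with the partner lane in the other half-warp
    # (stands in for the half-warp swap), then combine.
    swapped = [partial[(lane + HALF_WARP) % WARP_SIZE] for lane in range(WARP_SIZE)]
    return [max(own, other) for own, other in zip(partial, swapped)]
```

After the second stage both lanes of a pair hold the same column result, so no further broadcast is needed in this model before the value feeds the per-lane softmax state.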

Structs