Mojo module
mha_softmax
Online softmax row-state bundle for MhaPrefillV2.
OnlineSoftmax (below) owns the four row-state scalars maintained by
the FlashAttention-2 online-softmax recurrence (max_vec,
max_vec_prev, norm_vec, scale_vec) as Float32 fields and exposes
the recurrence steps, the cross-lane reductions, and the
column-broadcast ops as static methods. Cluster functions in
mha_prefill_v2.mojo thread a single OnlineSoftmax reference instead
of four separate RegTile parameters.
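The recurrence these four fields implement can be modeled in plain Python for a single row. This is a minimal sketch, not the Mojo API: the class OnlineSoftmaxRow, its update method, and the helpers attend_row and reference_row are hypothetical names, and only the four field names come from this page. Each tile folds new scores into the running max, rescales the earlier norm and output accumulator by scale_vec = exp(max_vec_prev - max_vec), and the output is divided by norm_vec once at the end.

```python
import math


class OnlineSoftmaxRow:
    """Hypothetical Python model of one row's online-softmax state.

    Field names mirror the four scalars named in the docs; the method
    name and control flow are illustrative, not the Mojo API.
    """

    def __init__(self) -> None:
        self.max_vec = -math.inf       # running row max
        self.max_vec_prev = -math.inf  # row max before the current tile
        self.norm_vec = 0.0            # running sum of exp(score - max_vec)
        self.scale_vec = 1.0           # pending rescale for earlier partials

    def update(self, scores: list[float]) -> list[float]:
        """Fold one tile of finite scores into the state; return exp-weights."""
        self.max_vec_prev = self.max_vec
        self.max_vec = max(self.max_vec_prev, max(scores))
        # exp(-inf - m) == 0.0, so the first tile discards the empty history.
        self.scale_vec = math.exp(self.max_vec_prev - self.max_vec)
        probs = [math.exp(s - self.max_vec) for s in scores]
        self.norm_vec = self.norm_vec * self.scale_vec + sum(probs)
        return probs


def attend_row(score_tiles, value_tiles):
    """Stream (scores, values) tiles for one Q row; return softmax(scores) @ V."""
    state = OnlineSoftmaxRow()
    dim = len(value_tiles[0][0])
    acc = [0.0] * dim
    for scores, values in zip(score_tiles, value_tiles):
        probs = state.update(scores)
        tile_out = [sum(p * v[j] for p, v in zip(probs, values)) for j in range(dim)]
        # Rescale the old accumulator to the new max before adding this tile.
        acc = [a * state.scale_vec + t for a, t in zip(acc, tile_out)]
    # Deferred normalization: divide by the running norm only once.
    return [a / state.norm_vec for a in acc]


def reference_row(scores, values):
    """Non-streaming softmax(scores) @ V over the full row, for checking."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]
```

Deferring the division by norm_vec to the last tile is the FlashAttention-2 choice; attend_row and reference_row should agree to floating-point rounding on the same scores and values.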
The col_l rt_32x32 accumulator topology gives each lane ownership of one column of the 32x32 fragment, and each column corresponds to one Q row in the warp's stripe. The online-softmax recurrence therefore tracks, per lane, one running max (plus the previous max), one running norm, and one pending scale: each of the four state fields is a single FP32 scalar in a VGPR.
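A small Python model of the lane-to-row mapping makes the ownership concrete. It assumes, as a hypothesis consistent with the half-warp redundancy described in the next paragraph, that lane l owns column l % 32, and it introduces a hypothetical stripe_start offset for the first Q row of the warp's stripe; neither helper below exists in the module.

```python
TILE = 32  # rt_32x32 fragment edge
WAVE = 64  # lanes in the wavefront: half-warps [0, 32) and [32, 64)
STATE_FIELDS = ("max_vec", "max_vec_prev", "norm_vec", "scale_vec")


def q_row_of_lane(lane: int, stripe_start: int) -> int:
    """Hypothetical: Q row tracked by `lane` under col_l.

    Assumes column = lane % 32, so lanes l and l + 32 share a column.
    """
    if not 0 <= lane < WAVE:
        raise ValueError(f"lane {lane} outside [0, {WAVE})")
    return stripe_start + lane % TILE


def state_vgprs_per_lane() -> int:
    # Each field is one FP32 scalar, i.e. one 32-bit VGPR in every lane.
    return len(STATE_FIELDS)
```

Under these assumptions the whole row state costs four VGPRs per lane, and, for example, lanes 5 and 37 both track Q row stripe_start + 5.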
Each column of an rt_32x32 is held redundantly across two half-warps
(lanes [0, 32) and [32, 64)). Per-column reductions first combine
in-lane via SIMD.reduce_* and then across half-warps via
permlane_swap[32], a single-cycle DPP-style swap. Using stdlib's
lane_group_reduce here would lower to ds_bpermute_b32, which routes
through LDS, so the reductions call permlane_swap directly.
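The two-stage reduction can be simulated over a 64-lane wavefront in Python. This is a sketch under stated assumptions: permlane_swap_32 is a hypothetical model in which lane l receives the value of lane l ^ 32, which matches the half-warp pairing described above but is not a statement of the exact hardware semantics, and reduce_columns is an illustrative helper rather than part of OnlineSoftmax.

```python
import operator
from functools import reduce

WAVE = 64
HALF = 32


def permlane_swap_32(vals):
    """Hypothetical model of the cross-half exchange: lane l gets lane l ^ 32."""
    return [vals[lane ^ HALF] for lane in range(WAVE)]


def reduce_columns(lane_vectors, op):
    """Two-stage per-column reduction over 64 lanes (illustrative).

    lane_vectors[l] holds lane l's SIMD elements of column l % 32;
    the other elements of that column live in lane l ^ 32.
    """
    # Stage 1: in-lane reduction (the role SIMD.reduce_* plays in the docs).
    partial = [reduce(op, vec) for vec in lane_vectors]
    # Stage 2: one swap with the partner half-warp, then combine.
    swapped = permlane_swap_32(partial)
    return [op(a, b) for a, b in zip(partial, swapped)]


# Usage: column maxima and column sums over the same lane layout.
# col_max = reduce_columns(lanes, max)
# col_sum = reduce_columns(lanes, operator.add)
```

Because the combining op (max or addition) is commutative, lanes l and l ^ 32 end up with identical results, which keeps the redundant per-column copy consistent after a single exchange.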
Structs
- OnlineSoftmax: Online softmax row-state bundle for MhaPrefillV2.