IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

OnlineSoftmax

struct OnlineSoftmax[att_dtype: DType = DType.float32]

Online softmax row-state bundle for MhaPrefillV2.

Parametrized on att_dtype: the dtype of the att_block RegTile this state operates against (FP32 for the BF16 prefill path and FP8 + KV<128; FP16 for FP8 + KV>=128 per _SOFTMAX_DTYPE in MhaPrefillV2 / MlaPrefillV2). All att_block-touching methods bind to Self.att_dtype, so the type checker rejects mismatched dtypes at the call site instead of silently coercing them. The four state scalars stay Float32 regardless; see "Accumulator dtype rationale" below.
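The dtype rule above can be restated as a small lookup. This is a hypothetical Python restatement of the rule as written here, not the actual _SOFTMAX_DTYPE code. The function name softmax_att_dtype is illustrative, and the `kv` parameter is assumed to be the quantity the text calls "KV".

```python
# Hypothetical restatement of the att_dtype selection rule described above;
# not the actual _SOFTMAX_DTYPE code. Assumption: `kv` is the value the text calls "KV".
def softmax_att_dtype(input_dtype: str, kv: int) -> str:
    """Return the att_block dtype ("float32" or "float16") for a prefill config."""
    if input_dtype == "bf16":
        return "float32"                      # BF16 prefill path: FP32 scores
    if input_dtype == "fp8":
        return "float32" if kv < 128 else "float16"
    raise ValueError(f"unsupported input dtype: {input_dtype}")


# The four recurrence scalars stay Float32 whatever att_dtype is chosen.
STATE_DTYPE = "float32"

for cfg in [("bf16", 128), ("fp8", 64), ("fp8", 128)]:
    print(cfg, "att:", softmax_att_dtype(*cfg), "state:", STATE_DTYPE)
```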

Owns the four row-state scalars maintained by the FlashAttention-2 online-softmax recurrence as direct Float32 fields:

  • max_vec: running rowmax (in log2 units; the reference prescales Q by scale * log2(e), so att values are already in log2 units).
  • max_vec_prev: rowmax from the previous tile. Used by the lazy-rescale comparison and by the unconditional rescale's exp2(prev - new). Shadow-updated to max_vec after each consumed cluster.
  • norm_vec: running denominator (exp-sum so far). Consumed by Epi-C12's o_reg /= norm_vec.
  • scale_vec: pending rescale factor exp2(max_prev - max_new). Applied to o_reg conditionally during lazy rescale (main loop) or unconditionally during the epilogue tails. Reset to 1 when no rescale fired, so norm_vec *= scale_vec is a safe identity.
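A small scalar model can check the recurrence above against a direct softmax. This Python sketch is a hypothetical reference, not the Mojo struct. It models one lane (one Q row), assumes scores are already prescaled into log2 units, collapses V to one scalar per key, and follows only the unconditional rescale path. The method names mirror the ones documented below for readability.

```python
class OnlineSoftmaxModel:
    """One lane / one Q row of the four-field recurrence (hypothetical model).

    Scores are assumed to be prescaled into log2 units, and V is reduced to one
    scalar per key so the output accumulator is a single float.
    """

    def __init__(self):
        # No-sink init: maxima and denominator start at 0, scale_vec at 1.
        self.max_vec = 0.0
        self.max_vec_prev = 0.0
        self.norm_vec = 0.0
        self.scale_vec = 1.0

    def seed_tile0(self, tile):
        self.max_vec = max(tile)
        self.max_vec_prev = self.max_vec
        return [x - self.max_vec for x in tile]

    def col_max_acc(self, tile):
        self.max_vec = max(self.max_vec_prev, max(tile))

    def update_scale_unconditional(self):
        self.scale_vec = 2.0 ** (self.max_vec_prev - self.max_vec)
        self.max_vec_prev = self.max_vec

    def sub_max(self, tile):
        return [x - self.max_vec for x in tile]

    def col_sum_acc(self, p):
        self.norm_vec += sum(p)


def attend_row(tiles, values):
    """tiles[i]: scores of KV tile i (log2 units); values[i]: matching V scalars."""
    s = OnlineSoftmaxModel()
    o = 0.0
    for i, (tile, v) in enumerate(zip(tiles, values)):
        if i == 0:
            shifted = s.seed_tile0(tile)
        else:
            s.col_max_acc(tile)
            s.update_scale_unconditional()
            o *= s.scale_vec                       # rescale_output
            s.norm_vec *= s.scale_vec              # apply_unconditional_norm_rescale
            shifted = s.sub_max(tile)
        p = [2.0 ** x for x in shifted]            # exp2 over the tile
        s.col_sum_acc(p)
        o += sum(pi * vi for pi, vi in zip(p, v))  # P @ V for this tile
    return o / s.norm_vec                          # normalize_output


def reference(tiles, values):
    """Direct two-pass base-2 softmax over all scores, for comparison."""
    xs = [x for t in tiles for x in t]
    vs = [x for t in values for x in t]
    m = max(xs)
    w = [2.0 ** (x - m) for x in xs]
    return sum(wi * vi for wi, vi in zip(w, vs)) / sum(w)


tiles = [[0.5, 3.0], [7.0, -2.0], [1.0, 9.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(attend_row(tiles, values), reference(tiles, values))
```

The rescale of both o and norm_vec before adding the new tile's terms is what keeps every partial sum expressed relative to the current running max.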

Each field is a per-lane FP32 scalar (1 VGPR/lane). The col_l rt_32x32 topology gives each lane ownership of one column of the fragment, i.e. one Q row in the warp's stripe, so per-lane scalar state is the natural representation. (Each column is held redundantly across the two half-warps that share it; both lanes store their own copy of the identical reduced value.)

Lifetime (MhaPrefillV2):

  • max_vec, max_vec_prev, scale_vec: prologue → Epi-C10 (last touched by _full_softmax_unconditional and the final rescale_output(o_reg)).
  • norm_vec: prologue → Epi-C12 (consumed by normalize_output(o_reg), two clusters after the other three die).

norm_vec outliving the other three by two clusters costs 1 VGPR over Epi-C11/C12; register pressure is not the binding constraint there.

Accumulator dtype rationale (why state stays Float32 regardless of att_dtype): the four scalars are per-lane (1 VGPR/lane each), so narrowing them to FP16 saves zero VGPRs (registers are dword-granular, and a half-VGPR scalar costs pack/unpack ALU on every read). norm_vec sums up to ~seq_len terms of exp2(att - max) ∈ [0, 1]; at seq=8192 the sum reaches ~2^13, leaving little FP16 headroom. The hardware v_exp_f32 / v_rcp_f32 instructions are FP32. Narrowing att_dtype to FP16 only pays off in the larger att_block tile storage, not in the recurrence scalars.
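The FP16 headroom argument can be checked numerically. This Python sketch emulates FP16 storage with the standard-library struct 'e' (binary16) format. The accumulator rounded after every add is a hypothetical narrowed norm_vec, used only to illustrate the point; the kernel does not do this.

```python
import math
import struct


def to_f16(x: float) -> float:
    """Round a Python float to IEEE binary16 and back (struct's 'e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def f16_ulp(x: float) -> float:
    """Gap between adjacent binary16 values at magnitude x (normal range only)."""
    _, exp = math.frexp(x)          # x = m * 2**exp with 0.5 <= m < 1
    return 2.0 ** (exp - 1 - 10)    # binary16 keeps 10 explicit mantissa bits


# At seq=8192 the denominator can reach ~2**13, where FP16 spacing is already 8.
print(f16_ulp(8192.0))              # 8.0
print(to_f16(8192.0 + 1.0))         # 8192.0: a whole exp2(0) = 1 term is lost

# Hypothetical FP16 accumulator, rounded after every add of a term in [0, 1].
norm16 = to_f16(8192.0)
norm64 = 8192.0
for _ in range(1000):
    norm16 = to_f16(norm16 + 0.9)
    norm64 += 0.9
print(norm16, norm64)               # norm16 stays 8192.0; norm64 is about 9092
```

Since the largest finite FP16 value is 65504, a sum near 2^13 also sits only about a factor of 8 below overflow.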

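Fields

  • max_vec (Float32): running rowmax, in log2 units.
  • max_vec_prev (Float32): rowmax from the previous tile.
  • norm_vec (Float32): running softmax denominator.
  • scale_vec (Float32): pending rescale factor exp2(max_prev - max_new).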

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable

Methods

__init__

def __init__(out self)

No-sink init: sets max_vec, max_vec_prev, and norm_vec to 0 and scale_vec to 1, so the epilogue's unconditional norm_vec *= scale_vec is a safe no-op when no rescale fired.

reseed_with_sink

def reseed_with_sink(mut self, sw_log2: Float32)

Re-init for the sink path: pre-seed the recurrence with the virtual sink token's contribution.

Setting max_vec = max_vec_prev = sw_log2 (log2e * sink_weight) keeps the rowmax in log2 units (the reference prescales Q by scale * log2e, so att values are in log2 units). norm_vec = 1 reflects the virtual sink's exp2(score - max) = exp2(0) = 1 contribution. Subsequent tiles update through the normal recurrence; the sink is rescaled implicitly as the running max grows. scale_vec stays at the 1 set by __init__.
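This Python sketch checks that seeding with the sink yields a softmax whose denominator includes the sink term. It is a hypothetical scalar model, not the Mojo method. It assumes scores are in log2 units and that the sink adds to the denominator only (it carries no V), as the norm_vec = 1 seeding implies. It also assumes every tile, including the first, takes the col_max_acc plus unconditional-rescale path; that sequencing is an assumption about the sink path.

```python
def attend_row_with_sink(sw_log2, tiles, values):
    """One lane's recurrence seeded with a virtual sink (hypothetical model)."""
    max_vec = max_vec_prev = sw_log2       # reseed_with_sink
    norm_vec = 1.0                         # exp2(sw_log2 - max_vec) = exp2(0) = 1
    o = 0.0
    for tile, v in zip(tiles, values):
        max_vec = max(max_vec_prev, max(tile))         # col_max_acc
        scale_vec = 2.0 ** (max_vec_prev - max_vec)    # update_scale_unconditional
        max_vec_prev = max_vec
        o *= scale_vec                                 # rescale_output
        norm_vec *= scale_vec                          # also rescales the sink's 1
        p = [2.0 ** (x - max_vec) for x in tile]       # sub_max + exp2
        norm_vec += sum(p)                             # col_sum_acc
        o += sum(pi * vi for pi, vi in zip(p, v))
    return o / norm_vec                                # normalize_output


def reference_with_sink(sw_log2, tiles, values):
    """Direct base-2 softmax with the sink as an extra, value-less logit."""
    xs = [x for t in tiles for x in t]
    vs = [x for t in values for x in t]
    m = max([sw_log2] + xs)
    w = [2.0 ** (x - m) for x in xs]
    denom = 2.0 ** (sw_log2 - m) + sum(w)
    return sum(wi * vi for wi, vi in zip(w, vs)) / denom


tiles = [[0.5, 3.0], [7.0, -2.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
for sw in (-5.0, 4.0, 10.0):
    print(sw, attend_row_with_sink(sw, tiles, values),
          reference_with_sink(sw, tiles, values))
```

A larger sink weight absorbs more probability mass, so the output shrinks toward 0 as sw_log2 grows past the real scores.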

seed_tile0

def seed_tile0[layout: TensorLayout, //](mut self, mut att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

Prologue tile-0 partial softmax setup.

Three-step composite: max_vec = col_max(att_block), max_vec_prev = max_vec, att_block -= max_vec. Called once after the first QK + mask, before the prologue's first-half exp2. There's no col_max_acc because there's no prior rowmax yet (no-sink path) or the sink contribution is already seeded into max_vec (sink path).

col_max_acc

def col_max_acc[layout: TensorLayout, //](mut self, att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

Running rowmax: max_vec = max(max_vec_prev, col_max(att_block)).

The caller maintains the max_vec_prev = max_vec shadow-write separately (via lazy_rescale_decision or update_scale_unconditional). Splitting the shadow-write out lets the cluster fns interpose IGLP barriers between the rowmax update and the scale update without dragging max_vec_prev along.

sub_max

def sub_max[layout: TensorLayout, //](mut self, mut att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

att_block -= max_vec per element. Prepares att_block for the subsequent exp2_inplace_range call.

seed_tile0_scaled

def seed_tile0_scaled[layout: TensorLayout, //](mut self, mut att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], log2_scale: Float32)

Scale-folded seed_tile0 for the FP32-in-place sequential cadence (_FP32_SOFTMAX_SCORES). att_block arrives RAW (un-scaled QK); log2_scale = scale * log2(e).

max_vec = log2_scale * col_max(raw) (the running max is kept in the SCALED log2 domain, so update_scale_unconditional's exp2(prev - new) rescale is byte-identical to the non-folded path), max_vec_prev = max_vec, then att = log2_scale*att - max_vec via one fused v_pk_fma per fragment. The col_max runs in place on the raw 64-VGPR att (no transient scaled copy), and the scale never touches the exp2 critical path. The math is identical: scale > 0 ⇒ max(scale·x) = scale·max(x).
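This Python sketch checks the scale-folding identity. The functions unfolded and folded are hypothetical names, and scale = 1/sqrt(128) is an example value. Multiplying by a positive constant preserves order even after rounding, so the folded max equals the unfolded one exactly. That exact match is what keeps the downstream exp2(prev - new) rescale unchanged.

```python
import math

LOG2E = 1.0 / math.log(2.0)


def unfolded(raw, scale):
    """Scale first (transient copy), then column max, subtract, exp2."""
    log2_scale = scale * LOG2E
    scaled = [log2_scale * x for x in raw]
    m = max(scaled)
    return m, [2.0 ** (y - m) for y in scaled]


def folded(raw, scale):
    """seed_tile0_scaled-style: max on RAW scores, scale folded into the subtract."""
    log2_scale = scale * LOG2E
    m = log2_scale * max(raw)              # valid because log2_scale > 0
    # The kernel fuses log2_scale * x - m into one v_pk_fma; Python rounds the
    # multiply and the subtract separately, so this mirrors the math only.
    return m, [2.0 ** (log2_scale * x - m) for x in raw]


raw = [1.5, -3.25, 7.0, 0.0, 6.75]
scale = 1.0 / math.sqrt(128.0)
m_u, p_u = unfolded(raw, scale)
m_f, p_f = folded(raw, scale)
print(m_u == m_f, p_u == p_f, max(p_f))   # True True 1.0
```

A fused multiply-add rounds once instead of twice, so on hardware the per-element exponent arguments may differ from the unfused form in the last bit. The running max itself does not depend on that fusion.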

Uses _col_max_scalar_v3max (v_max3_f32-folded, FP32-scores-only) instead of _col_max_scalar. The column max is the same; only the in-lane fold grouping changes, so ISel emits v_max3_f32.
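Max is exact and order-independent for finite values, so regrouping the in-lane fold changes only the number of fold steps, not the result. This Python sketch counts fold steps for a 2-input versus a 3-input max over one hypothetical 16-value column. The counts describe this toy fold, not measured ISA output.

```python
def col_max_pairwise(xs):
    """Fold with a 2-input max per element (v_max_f32-style). Returns (max, ops)."""
    acc, ops = xs[0], 0
    for x in xs[1:]:
        acc = max(acc, x)
        ops += 1
    return acc, ops


def col_max_max3(xs):
    """Fold two new elements per step with a 3-input max (v_max3_f32-style)."""
    acc, ops, i = xs[0], 0, 1
    while i + 1 < len(xs):
        acc = max(acc, xs[i], xs[i + 1])
        ops += 1
        i += 2
    if i < len(xs):                 # odd leftover element
        acc = max(acc, xs[i])
        ops += 1
    return acc, ops


xs = [3.0, -1.0, 7.0, 2.0, 9.0, 0.0, 4.0, 4.0,
      -5.0, 8.0, 1.0, 6.0, 2.0, 3.0, 5.0, 7.0]
print(col_max_pairwise(xs), col_max_max3(xs))   # (9.0, 15) (9.0, 8)
```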

col_max_acc_scaled

def col_max_acc_scaled[layout: TensorLayout, //](mut self, att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], log2_scale: Float32)

Scale-folded col_max_acc: max_vec = max(max_prev, log2_scale * col_max(raw att)). Keeps the running max in the scaled log2 domain while att_block stays RAW. See seed_tile0_scaled.

Uses _col_max_scalar_v3max (v_max3_f32-folded) for the raw column max. The running-max fold-in stays the OUTER max(...) here rather than a seed into the chain: max_prev is in the SCALED log2 domain but the chain reduces RAW scores, so seeding the chain with max_prev would compare across domains. The log2_scale * raw_max then max(max_prev, ...) ordering keeps the rescale exp2(prev - new) byte-identical to the non-folded path; the outer max is one residual v_max_f32 per tile.
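The domain mismatch is easiest to see with numbers. In this Python sketch, seeded_chain is a hypothetical incorrect variant, not code from the kernel, that feeds the scaled max_prev into the raw reduction. log2_scale = 0.25 is an example value chosen to be exact in binary. Whenever max_prev should win, the variant returns a smaller value, so the running max is no longer monotone.

```python
def col_max_acc_scaled(max_prev, raw, log2_scale):
    """Documented ordering: reduce RAW scores, scale once, fold max_prev outside."""
    return max(max_prev, log2_scale * max(raw))


def seeded_chain(max_prev, raw, log2_scale):
    """Hypothetical wrong variant: SCALED max_prev seeded into the RAW reduction."""
    return log2_scale * max([max_prev] + raw)


log2_scale = 0.25                # example value; exact in binary
raw = [4.0, 6.0]                 # raw scores, scaled max = 1.5
for max_prev in (1.0, 2.0, 10.0):
    print(max_prev, col_max_acc_scaled(max_prev, raw, log2_scale),
          seeded_chain(max_prev, raw, log2_scale))
```

With max_prev = 2.0 the variant returns 1.5, and with max_prev = 10.0 it returns 2.5. In both cases the running max decreases. The two agree only when the new tile's scaled max already dominates.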

sub_max_scaled

def sub_max_scaled[layout: TensorLayout, //](mut self, mut att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], log2_scale: Float32)

Scale-folded sub_max: att = log2_scale*att - max_vec via one fused v_pk_fma per fragment (max_vec already scaled). Folds the QK scale into the max-subtract so exp2 stays plain and col_max ran in place on raw att. See seed_tile0_scaled.

lazy_rescale_decision

def lazy_rescale_decision[att_full_dtype: DType](mut self, mut o_reg: TileTensor[DType.float32, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut att_bf16_full: TileTensor[att_full_dtype, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], threshold: Float32) -> Bool

Lazy-rescale decision for main-loop C2/C6.

Returns True iff the running max grew by more than threshold log2 units in this cluster, in which case scale_vec = exp2(prev - new) was applied to o_reg and max_vec_prev was shadow-updated. On the skip path (most clusters), rolls back max_vec to max_vec_prev and resets scale_vec to 1.

The per-lane predicate is reduced with a wave-wide AND via a 64-bit ballot compared against the full-exec mask (attend_ker always runs with all 64 lanes active).
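This is a hypothetical Python model of the lazy path, with two lanes standing in for the 64-lane wave. all() models the ballot-based wave-AND, and threshold=8.0 is an arbitrary example value, since the text does not give the real threshold. The pending norm_vec rescale is applied in the same step, whereas the kernel defers it to the next C0/C4. On the skip path exp2(att - max_vec) can exceed 1, by at most 2**threshold. Numerator and denominator share the same stale max, so the normalized result still matches a direct softmax.

```python
def lazy_rescale_decision(state, o, threshold):
    """Per-wave decision; state holds one dict per lane, o is mutated in place."""
    # Wave-AND of "this lane's max grew by at most threshold" (models the ballot).
    if all(s["max"] - s["prev"] <= threshold for s in state):
        for s in state:
            s["max"] = s["prev"]   # roll back: keep using the stale max
            s["scale"] = 1.0       # SCALE_VEC INVARIANT
        return False
    for lane, s in enumerate(state):
        s["scale"] = 2.0 ** (s["prev"] - s["max"])
        o[lane] *= s["scale"]      # rescale o_reg
        s["prev"] = s["max"]       # shadow-update max_vec_prev
    return True


def attend_wave_lazy(tiles, values, threshold):
    """tiles[t][lane], values[t][lane]: one lane's scores (log2 units) and V."""
    lanes = len(tiles[0])
    o = [0.0] * lanes
    state = []
    for lane in range(lanes):                      # seed_tile0 on tile 0
        m = max(tiles[0][lane])
        state.append({"max": m, "prev": m, "norm": 0.0, "scale": 1.0})
    decisions = []
    for t, (tile, vals) in enumerate(zip(tiles, values)):
        if t > 0:
            for lane, s in enumerate(state):       # col_max_acc
                s["max"] = max(s["prev"], max(tile[lane]))
            fired = lazy_rescale_decision(state, o, threshold)
            decisions.append(fired)
            if fired:                              # norm rescale, applied at once here
                for s in state:
                    s["norm"] *= s["scale"]
        for lane, s in enumerate(state):
            p = [2.0 ** (x - s["max"]) for x in tile[lane]]  # may exceed 1 if skipped
            s["norm"] += sum(p)                               # col_sum_acc
            o[lane] += sum(pi * vi for pi, vi in zip(p, vals[lane]))
    return [o[lane] / state[lane]["norm"] for lane in range(lanes)], decisions


def reference_row(scores, vs):
    m = max(scores)
    w = [2.0 ** (x - m) for x in scores]
    return sum(wi * vi for wi, vi in zip(w, vs)) / sum(w)


# Two lanes, three tiles: tile 1 grows by at most 1 (skip), tile 2 by 19 (fire).
tiles = [[[0.0, 1.0], [0.0, 0.0]],
         [[2.0, 1.5], [0.5, 0.2]],
         [[20.0, 0.0], [1.0, 1.0]]]
values = [[[1.0, 2.0], [1.0, 1.0]],
          [[3.0, 4.0], [2.0, 2.0]],
          [[5.0, 6.0], [3.0, 3.0]]]
out, decisions = attend_wave_lazy(tiles, values, threshold=8.0)
print(decisions)    # [False, True]
for lane in range(2):
    scores = [x for t in tiles for x in t[lane]]
    vs = [x for t in values for x in t[lane]]
    print(out[lane], reference_row(scores, vs))
```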

SCALE_VEC INVARIANT: scale_vec is exactly 1 whenever no rescale fired in the most recent C2/C6. The else-branch reset is load-bearing. Without it, a stale scale_vec ≈ 1e-38 (from a non-Causal mask's sentinel-driven huge initial growth, where math_exp2(-10_000) = 1.18e-38 is the smallest float32 normal and does NOT flush to 0) gets re-applied 3× in the epilogue tail clusters → norm_vec flushes to 0 → the final divide produces Inf.

Returns:

Bool
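The failure mode behind the SCALE_VEC INVARIANT can be reproduced in FP32 arithmetic. This Python sketch rounds through the standard-library struct 'f' format. 2**-126 stands in for the stale scale_vec, and the 4096.0 starting denominator is an arbitrary example. The second of the three re-applications already underflows norm_vec to 0, which is what makes the final o_reg / norm_vec divide produce Inf.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float to IEEE binary32 (models an FP32 VGPR)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


# Stand-in for a stale scale_vec: 2**-126 (~1.18e-38), the smallest normal float32.
STALE_SCALE = f32(2.0 ** -126)

norm_stale = f32(4096.0)          # example denominator
norm_reset = f32(4096.0)
for _ in range(3):                # one multiply per epilogue tail cluster
    norm_stale = f32(norm_stale * STALE_SCALE)
    norm_reset = f32(norm_reset * 1.0)   # scale_vec reset to 1: identity
print(norm_stale, norm_reset)     # 0.0 4096.0
```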

update_scale_unconditional

def update_scale_unconditional(mut self)

UNCONDITIONAL rescale: scale_vec = exp2(max_prev - max_new), then max_vec_prev = max_vec.

Used by the epilogue tail/full softmax and by _pv_whole_with_partial_softmax, where the rescale always fires (no lazy-rescale skip). The caller is responsible for the rescale_output(o_reg) step AFTER any IGLP barriers. This method does not touch o_reg, so the cluster fn can interleave the rescale with PV MFMAs.

rescale_output

def rescale_output(mut self, mut o_reg: TileTensor[DType.float32, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

o_reg *= scale_vec per element.

Used by _pv_whole_with_partial_softmax (after the IGLP barrier separating PV MFMAs from VALU work) and the final Epi-C10 step before the divide.

apply_norm_rescale_if_pending

def apply_norm_rescale_if_pending(mut self, pending_scale: Bool)

if pending_scale: norm_vec *= scale_vec.

Used by the C0/C4 tail softmax to roll forward a lazy rescale that fired in the previous C2/C6.

apply_unconditional_norm_rescale

def apply_unconditional_norm_rescale(mut self)

UNCONDITIONAL norm_vec *= scale_vec.

Used by the epilogue tail/full softmax. Relies on the scale_vec == 1 invariant when no rescale fired (lazy_rescale_decision maintains this on the skip branch).

col_sum_acc

def col_sum_acc[layout: TensorLayout, //](mut self, att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

Running denominator: norm_vec += sum(att_block, axis=row).

normalize_output

def normalize_output(mut self, mut o_reg: TileTensor[DType.float32, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

Final o_reg /= norm_vec in place.

Avoids materializing a second FP32 O_LAYOUT tile (64 VGPRs/lane). The combined live set of that tile and o_reg pushed FP8 KV=128 over the 128 VGPR/thread cap and spilled 9 VGPR-equivalents to scratch in Epi-C12. Used at the very end of the kernel.