Mojo struct
OnlineSoftmax
struct OnlineSoftmax[att_dtype: DType = DType.float32]
Online softmax row-state bundle for MhaPrefillV2.
Parametrized on att_dtype: the dtype of the att_block
RegTile this state operates against (FP32 for the BF16 prefill
path and for FP8 + KV<128; FP16 for FP8 + KV>=128 per
_SOFTMAX_DTYPE in MhaPrefillV2 / MlaPrefillV2). All
att_block-touching methods bind to Self.att_dtype so the type checker rejects
mismatched dtypes at the call site instead of silently coercing.
The four state scalars stay Float32 regardless; see
"Accumulator dtype rationale" below.
Owns the four row-state scalars maintained by the FlashAttention-2
online-softmax recurrence as direct Float32 fields:
- max_vec: running rowmax (in log2 units; the reference prescales Q by scale * log2(e) so att values are already in log2 units).
- max_vec_prev: rowmax from the previous tile. Used by the lazy-rescale comparison and the unconditional rescale's exp2(prev - new). Shadow-updated to max_vec after each consumed cluster.
- norm_vec: running denominator (exp-sum so far). Consumed at Epi-C12's o_reg /= norm_vec.
- scale_vec: pending rescale factor exp2(max_prev - max_new). Conditionally applied to o_reg during lazy-rescale (main loop) or unconditionally during the epilogue tails. Reset to 1 when no rescale fired so norm_vec *= scale_vec is a safe identity.
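The four scalars above can be read as the state of a single-row recurrence. The following is a minimal Python sketch of that recurrence for one Q row, not the kernel's code: the function names, the scalar (head-dim 1) values, and the always-rescale policy are assumptions made for illustration, and scores are taken to be already in log2 units.

```python
def exp2(x):
    return 2.0 ** x

def online_softmax_row(tiles_scores, tiles_values):
    """Hypothetical single-row model of the four-scalar recurrence."""
    max_vec = max_vec_prev = max(tiles_scores[0])  # seeded from the first tile
    norm_vec, scale_vec, o_reg = 0.0, 1.0, 0.0
    for scores, values in zip(tiles_scores, tiles_values):
        max_vec = max(max_vec_prev, max(scores))      # running rowmax
        scale_vec = exp2(max_vec_prev - max_vec)      # pending rescale factor
        max_vec_prev = max_vec                        # shadow update
        o_reg *= scale_vec                            # rescale the output so far
        norm_vec *= scale_vec                         # rescale the denominator so far
        probs = [exp2(s - max_vec) for s in scores]   # each term lies in (0, 1]
        norm_vec += sum(probs)
        o_reg += sum(p * v for p, v in zip(probs, values))
    return o_reg / norm_vec                           # final normalize

def direct_softmax_row(scores, values):
    """Two-pass base-2 softmax-weighted sum, used only as a reference."""
    m = max(scores)
    weights = [exp2(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)
```

Processing the tiles one at a time should agree with the two-pass reference up to floating-point rounding, which is the property the recurrence exists to provide.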
Each field is a per-lane FP32 scalar (1 VGPR/lane). The col_l rt_32x32 topology gives each lane ownership of one column of the fragment = one Q row in the warp's stripe, so per-lane scalar state is the natural representation. (Each column is held redundantly across the two half-warps sharing it: both lanes store their own copy of the identical reduced value.)
Lifetime (MhaPrefillV2):
- max_vec, max_vec_prev, scale_vec: prologue → Epi-C10 (last touched by _full_softmax_unconditional + the final rescale_output(o_reg)).
- norm_vec: prologue → Epi-C12 (consumed by normalize_output(o_reg), two clusters after the other three die).
norm_vec outliving the other three by two clusters costs 1 VGPR
over Epi-C11/C12; the kernel is not register-pressure-constrained there.
Accumulator dtype rationale (why state stays Float32 regardless
of att_dtype): the four scalars are per-lane (1 VGPR/lane each),
so narrowing to FP16 saves zero VGPRs (registers are dword-granular;
a half-VGPR scalar costs pack/unpack ALU on every read). norm_vec
sums up to ~seq_len terms of exp2(att - max) ∈ [0, 1]; at
seq=8192 the sum reaches ~2^13, leaving little FP16 headroom.
Hardware v_exp_f32 / v_rcp_f32 is the FP32 path. Narrowing
att_dtype to FP16 only pays off on the larger att_block tile
storage, not on the recurrence scalars.
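The headroom argument above can be checked numerically. The sketch below is an illustration under stated assumptions: it uses the worst case where every term equals 1.0, emulates FP16 by rounding through Python's `struct` half-precision format, and uses Python's double as a stand-in for the FP32 accumulator (FP32 also represents these integer sums exactly). It also shows a precision effect the text does not mention: an FP16 running sum of ones stops growing long before it reaches the range limit.

```python
import math
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def accumulate(n_terms, term, rounder):
    """Sum n_terms copies of term, rounding the running total each step."""
    total = 0.0
    for _ in range(n_terms):
        total = rounder(total + term)
    return total

FP16_MAX = 65504.0          # largest finite binary16 value
seq_len = 8192              # the sequence length quoted in the text

# Range headroom: how many doublings fit between the worst-case sum and FP16_MAX.
headroom_bits = math.log2(FP16_MAX / seq_len)

# Precision: above 2048 the FP16 spacing is 2, so adding 1.0 rounds back down.
fp16_sum = accumulate(seq_len, 1.0, to_fp16)
wide_sum = accumulate(seq_len, 1.0, float)
```

Here `headroom_bits` comes out just under 3, `wide_sum` is the exact 8192, and `fp16_sum` stalls at 2048.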
Fields
- max_vec (Float32)
- max_vec_prev (Float32)
- norm_vec (Float32)
- scale_vec (Float32)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable
Methods
__init__
def __init__(out self)
No-sink init: sets max_vec, max_vec_prev, and norm_vec to 0 and scale_vec to 1, so the epilogue's unconditional norm_vec *= scale_vec is a safe no-op when no rescale fired.
reseed_with_sink
def reseed_with_sink(mut self, sw_log2: Float32)
Re-init for the sink path: pre-seed the recurrence with the virtual sink token's contribution.
max_vec = max_vec_prev = log2e * sink_weight keeps the rowmax
in log2 units (the reference prescales Q by scale * log2e, so
att values are in log2 units). norm_vec = 1 reflects the
virtual sink's exp2(score - max) = exp2(0) = 1 contribution.
Subsequent tiles update through the normal recurrence; the sink
is rescaled implicitly as the running max grows. scale_vec
stays at the 1 set by __init__.
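The sink seeding described above can be sketched as a scalar model. This is an illustration only: the function names are hypothetical, sw_log2 is assumed to be the sink weight already expressed in log2 units, and every tile (including the first) is assumed to fold into the seeded max through the ordinary running-max update. How the real kernel orders the first tile is not modeled.

```python
def exp2(x):
    return 2.0 ** x

def denominator_with_sink(sw_log2, tiles):
    """Seed the recurrence with the sink, then run the usual updates."""
    max_vec = max_vec_prev = sw_log2
    norm_vec = 1.0                                # exp2(sw_log2 - max_vec) = exp2(0)
    for scores in tiles:
        max_vec = max(max_vec_prev, max(scores))
        scale_vec = exp2(max_vec_prev - max_vec)  # also rescales the sink's share
        max_vec_prev = max_vec
        norm_vec = norm_vec * scale_vec + sum(exp2(s - max_vec) for s in scores)
    return max_vec, norm_vec

def denominator_direct(sw_log2, tiles):
    """Reference: treat the sink as one extra score in a two-pass softmax."""
    flat = [sw_log2] + [s for tile in tiles for s in tile]
    m = max(flat)
    return m, sum(exp2(s - m) for s in flat)
```

Under these assumptions the seeded recurrence and the reference produce the same max and, up to rounding, the same denominator, which is what "the sink is rescaled implicitly as the running max grows" amounts to.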
seed_tile0
def seed_tile0[layout: TensorLayout, //](mut self, mut att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
Prologue tile-0 partial softmax setup.
Three-step composite: max_vec = col_max(att_block),
max_vec_prev = max_vec, att_block -= max_vec. Called once
after the first QK + mask, before the prologue's first-half
exp2. There's no col_max_acc because there's no prior
rowmax yet (no-sink path) or the sink contribution is already
seeded into max_vec (sink path).
col_max_acc
def col_max_acc[layout: TensorLayout, //](mut self, att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
Running rowmax: max_vec = max(max_vec_prev, col_max(att_block)).
Caller maintains the max_vec_prev = max_vec shadow-write
separately (via lazy_rescale_decision or
update_scale_unconditional). Splitting the shadow-write out
lets the cluster fns interpose IGLP barriers between the
rowmax update and the scale update without dragging
max_vec_prev along.
sub_max
def sub_max[layout: TensorLayout, //](mut self, mut att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
att_block -= max_vec per element. Prepares att_block for the subsequent exp2_inplace_range call.
seed_tile0_scaled
def seed_tile0_scaled[layout: TensorLayout, //](mut self, mut att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], log2_scale: Float32)
Scale-folded seed_tile0 for the FP32-in-place sequential cadence (_FP32_SOFTMAX_SCORES). att_block arrives RAW (un-scaled QK); log2_scale = scale * log2(e).
max_vec = log2_scale * col_max(raw) (running max kept in the
SCALED log2 domain, so update_scale_unconditional's
exp2(prev - new) rescale is byte-identical to the non-folded
path), max_vec_prev = max_vec, then att = log2_scale*att - max_vec
via one fused v_pk_fma per fragment. The col_max
runs in place on the raw 64-VGPR att (no transient scaled copy)
and the scale never touches the exp2 critical path. The math is
identical: scale > 0 ⇒ max(scale·x) = scale·max(x).
Uses _col_max_scalar_v3max (v_max3_f32-folded, FP32-scores-only)
instead of _col_max_scalar: same column max, only the
in-lane fold grouping changes so ISel emits v_max3_f32.
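The identity max(scale·x) = scale·max(x) for scale > 0 is what lets the column max run on raw scores. The sketch below compares a scale-first path with a folded path on one column; both function names are hypothetical. Python has no fused multiply-add in this form, so the folded subtraction here rounds twice where a hardware FMA would round once.

```python
def seed_unfolded(raw, log2_scale):
    """Scale first, then take the max and subtract (needs a scaled copy)."""
    scaled = [log2_scale * x for x in raw]        # transient scaled copy
    max_vec = max(scaled)
    return max_vec, [s - max_vec for s in scaled]

def seed_folded(raw, log2_scale):
    """Take the max on raw scores, then fold the scale into the subtract."""
    max_vec = log2_scale * max(raw)               # column max on raw scores
    att = [log2_scale * x - max_vec for x in raw] # one multiply-subtract per element
    return max_vec, att
```

Because multiplying by a positive constant preserves ordering even after rounding, the two paths pick the same element and produce the same max_vec, and the largest shifted score is exactly 0 in both.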
col_max_acc_scaled
def col_max_acc_scaled[layout: TensorLayout, //](mut self, att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], log2_scale: Float32)
Scale-folded col_max_acc: max_vec = max(max_prev, log2_scale * col_max(raw att)). Keeps the running max in the scaled log2 domain while att_block stays RAW. See seed_tile0_scaled.
Uses _col_max_scalar_v3max (v_max3_f32-folded) for the raw
column max. The running-max fold-in stays the OUTER max(...)
here rather than a seed into the chain: max_prev is in the
SCALED log2 domain but the chain reduces RAW scores, so seeding
the chain with max_prev would compare across domains. The
log2_scale * raw_max then max(max_prev, ...) ordering keeps
the rescale exp2(prev - new) byte-identical to the non-folded
path; the outer max is one residual v_max_f32 per tile.
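The cross-domain hazard described above is easy to reproduce with small numbers. This sketch contrasts the outer-max ordering with a hypothetical variant that seeds the raw reduction chain with the scaled max_prev; both function names are made up for illustration.

```python
def running_max_outer(max_prev, raw, log2_scale):
    """Reduce raw scores, scale the result, then fold in the scaled max_prev."""
    return max(max_prev, log2_scale * max(raw))

def running_max_seeded_chain(max_prev, raw, log2_scale):
    """INCORRECT on purpose: compares a scaled value against raw scores."""
    return log2_scale * max([max_prev] + raw)

# max_prev = 3.0 is in the scaled domain; the raw scores scale down to 2.0 and 2.5.
correct = running_max_outer(3.0, [4.0, 5.0], 0.5)
mixed = running_max_seeded_chain(3.0, [4.0, 5.0], 0.5)
```

With these inputs the outer-max form keeps the previous max of 3.0, while the seeded chain returns 2.5 because the scaled 3.0 loses to the raw 5.0 before the scale is applied.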
sub_max_scaled
def sub_max_scaled[layout: TensorLayout, //](mut self, mut att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], log2_scale: Float32)
Scale-folded sub_max: att = log2_scale*att - max_vec via one fused v_pk_fma per fragment (max_vec already scaled). Folds the QK scale into the max-subtract so exp2 stays plain and col_max ran in place on raw att. See seed_tile0_scaled.
lazy_rescale_decision
def lazy_rescale_decision[att_full_dtype: DType](mut self, mut o_reg: TileTensor[DType.float32, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut att_bf16_full: TileTensor[att_full_dtype, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], threshold: Float32) -> Bool
Lazy-rescale decision for main-loop C2/C6.
Returns True iff the running max grew by more than
threshold log2 units in this cluster, in which case
scale_vec = exp2(prev - new) was applied to o_reg and
max_vec_prev was shadow-updated. On the skip path (most
clusters), rolls back max_vec to max_vec_prev and resets
scale_vec to 1.
Wave-AND reduce via 64-bit ballot against full-exec mask
(attend_ker always runs all 64 lanes active).
SCALE_VEC INVARIANT: scale_vec is exactly 1 whenever no
rescale fired in the most recent C2/C6. The else-branch reset
is load-bearing: without it, a stale scale_vec ≈ 1e-38
(from a non-Causal mask's sentinel-driven huge initial growth,
where math_exp2(-10_000) = 1.18e-38 is the smallest float32
normal and does NOT flush to 0) gets re-applied 3× in the
epilogue tail clusters → norm_vec flushes to 0 → final
divide produces Inf.
Returns:
Bool: True if the rescale fired and was applied to o_reg, False on the skip path.
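The fire and skip branches can be modeled for a single row as follows. This is a sketch under assumptions, not the kernel's logic: the dict-based state, the `attend_row` driver, and the threshold value used in the usage note are hypothetical, the wave-level ballot is not modeled, and the denominator is rescaled immediately rather than being deferred to a later cluster.

```python
def exp2(x):
    return 2.0 ** x

def lazy_rescale_decision(state, o_reg, threshold):
    """Single-row model. Returns (fired, o_reg)."""
    if state["max_vec"] - state["max_vec_prev"] > threshold:
        state["scale_vec"] = exp2(state["max_vec_prev"] - state["max_vec"])
        o_reg *= state["scale_vec"]
        state["max_vec_prev"] = state["max_vec"]   # shadow update
        return True, o_reg
    state["max_vec"] = state["max_vec_prev"]       # skip path: roll back the max
    state["scale_vec"] = 1.0                       # skip path: keep the invariant
    return False, o_reg

def attend_row(tiles_scores, tiles_values, threshold):
    """Hypothetical driver: returns (output, number of rescales that fired)."""
    first_max = max(tiles_scores[0])
    state = {"max_vec": first_max, "max_vec_prev": first_max, "scale_vec": 1.0}
    norm_vec, o_reg, fired_count = 0.0, 0.0, 0
    for i, (scores, values) in enumerate(zip(tiles_scores, tiles_values)):
        if i > 0:
            state["max_vec"] = max(state["max_vec_prev"], max(scores))
            fired, o_reg = lazy_rescale_decision(state, o_reg, threshold)
            if fired:
                norm_vec *= state["scale_vec"]
                fired_count += 1
        # On the skip path these terms may exceed 1; the ratio is unaffected.
        probs = [exp2(s - state["max_vec"]) for s in scores]
        norm_vec += sum(probs)
        o_reg += sum(p * v for p, v in zip(probs, values))
    return o_reg / norm_vec, fired_count
```

Running the same tiles with a large threshold and with a threshold of 0 should give the same output with fewer rescales in the lazy case, since the final ratio does not depend on which reference max was subtracted.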
update_scale_unconditional
def update_scale_unconditional(mut self)
UNCONDITIONAL rescale: scale_vec = exp2(max_prev - max_new), then max_vec_prev = max_vec.
Used by epilogue tail/full softmax and by
_pv_whole_with_partial_softmax where rescale always fires
(no lazy-rescale skip). Caller is responsible for the
rescale_output(o_reg) step AFTER any IGLP barriers β this
method does not touch o_reg so the cluster fn can interleave
the rescale with PV MFMAs.
rescale_output
def rescale_output(mut self, mut o_reg: TileTensor[DType.float32, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
o_reg *= scale_vec per element.
Used by _pv_whole_with_partial_softmax (after the IGLP
barrier separating PV MFMAs from VALU work) and the final
Epi-C10 step before the divide.
apply_norm_rescale_if_pending
def apply_norm_rescale_if_pending(mut self, pending_scale: Bool)
if pending_scale: norm_vec *= scale_vec.
Used by the C0/C4 tail softmax to roll forward a lazy rescale that fired in the previous C2/C6.
apply_unconditional_norm_rescale
def apply_unconditional_norm_rescale(mut self)
UNCONDITIONAL norm_vec *= scale_vec.
Used by the epilogue tail/full softmax. Relies on the
scale_vec == 1 invariant when no rescale fired
(lazy_rescale_decision maintains this on the skip branch).
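Why the scale_vec == 1 invariant matters can be illustrated by emulating float32 rounding. The sketch below is an assumption-laden model: it rounds through Python's `struct` single-precision format, uses 2^-126 (the smallest float32 normal, about 1.18e-38) as the stale factor, picks an arbitrary denominator of 5.0, and applies the multiply three times to mirror the three tail applications mentioned for lazy_rescale_decision.

```python
import struct

def f32(x):
    """Round a Python float to the nearest IEEE binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

STALE_SCALE = f32(2.0 ** -126)   # smallest float32 normal, about 1.18e-38

def apply_tail_rescales(norm_vec, scale_vec, times):
    """Repeat norm_vec *= scale_vec with float32 rounding after each step."""
    for _ in range(times):
        norm_vec = f32(norm_vec * scale_vec)
    return norm_vec

stale = apply_tail_rescales(f32(5.0), STALE_SCALE, 3)  # stale factor left in place
reset = apply_tail_rescales(f32(5.0), f32(1.0), 3)     # factor reset to 1 on skip
```

With the stale factor the denominator underflows to 0.0, so a subsequent divide would produce Inf; with the factor reset to 1 the denominator is unchanged.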
col_sum_acc
def col_sum_acc[layout: TensorLayout, //](mut self, att_block: TileTensor[att_dtype, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
Running denominator: norm_vec += sum(att_block, axis=row).
normalize_output
def normalize_output(mut self, mut o_reg: TileTensor[DType.float32, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
Final o_reg /= norm_vec in place.
Avoids materializing a second FP32 O_LAYOUT tile (64 VGPRs/lane;
the combined live set with o_reg pushed FP8 KV=128 over
the 128 VGPR/thread cap and spilled 9 VGPR-equivalents to
scratch in Epi-C12). Used at the very end of the kernel.