Mojo module

hk_mha_softmax

Softmax primitives for HKMhaPrefill: cross-lane reductions and column-broadcast ops over col_l rt_32x32 register tiles.

Each column of an rt_32x32 is held redundantly across two half-warps (lanes [0, 32) and [32, 64)). Per-column reductions combine in-lane via SIMD.reduce_* and then across half-warps via permlane_swap[32], a single-cycle DPP-style swap. Using stdlib's lane_group_reduce here would lower to ds_bpermute_b32 (LDS-routed), so we go through permlane_swap directly.
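The two-stage reduction described above can be modelled on the host. The following is a minimal Python sketch, not the module's code. It assumes that lane l owns column l % 32, that the lower half-warp holds rows 0..15 and the upper half-warp rows 16..31 (the real register layout is likely interleaved), and that the swap lets each lane read the value held by lane l ^ 32. The names permlane_swap32_model and col_max_model are hypothetical.

```python
WAVE_SIZE = 64       # lanes per wave
HALF = 32            # lanes per half-warp
ROWS_PER_LANE = 16   # assumed: each lane holds half of its column's 32 rows


def permlane_swap32_model(per_lane):
    """Assumed model of the half-warp exchange: lane l reads lane l ^ 32."""
    return [per_lane[lane ^ HALF] for lane in range(WAVE_SIZE)]


def col_max_model(tile):
    """Column max of a 32x32 tile using an in-lane step and one swap.

    `tile` is a list of 32 rows, each a list of 32 numbers.
    Returns one value per lane; lanes l and l + 32 end up with the
    same result because both own column l % 32.
    """
    partial = []
    for lane in range(WAVE_SIZE):
        col = lane % HALF
        row0 = (lane // HALF) * ROWS_PER_LANE  # simplified row ownership
        # Stage 1: in-lane reduction (stands in for SIMD.reduce_max).
        partial.append(max(tile[r][col] for r in range(row0, row0 + ROWS_PER_LANE)))
    # Stage 2: combine with the partner lane in the other half-warp.
    other = permlane_swap32_model(partial)
    return [max(a, b) for a, b in zip(partial, other)]
```

After the swap every lane holds the full column maximum, so no further broadcast is needed before the column-broadcast ops consume the result.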

Functions

  • col_max: Computes col_accum[j] = max(src[*, j, *]).
  • col_max_acc: Running-max for online softmax: col_accum[j] = max(src_accum[j], max(src[*, j, *])).
  • col_sum_acc: Running-norm for online softmax: col_accum[j] = src_accum[j] + sum(src[*, j, *]).
  • div_col: dst[gr, gc] = src[gr, gc] / vec[gc] (final o_reg / norm_vec).
  • mul_col_inplace: dst[gr, gc] *= vec[gc] in place. This is the lazy-rescale step of online softmax (o_reg *= exp2(max_prev - max_new) when the running max grows).
  • rv_all_below: Returns wave-uniform True iff every lane satisfies max_new - max_prev <= threshold. Wave AND-reduce via a 64-bit ballot compared against the full-exec mask (attend_ker always runs all 64 lanes active). Used by the lazy-rescale skip path.
  • sub_col_inplace: dst[gr, gc] -= vec[gc] in place. Computes att_block - max_vec before exp2 in online softmax.
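To show how these primitives might fit together, here is a minimal pure-Python reference of a column-wise online softmax with a lazy-rescale skip. It is a sketch under stated assumptions, not the kernel's implementation. It assumes that scores are already in base-2 units, that each key carries a single scalar value, and that the skip path keeps the previous max as the reference when every column's max grew by at most threshold. The function names all_below and online_softmax_columns are hypothetical; comments mark which documented primitive each step stands in for.

```python
def all_below(max_new, max_prev, threshold):
    """Stand-in for rv_all_below: plain all() replaces the ballot AND-reduce."""
    return all(n - p <= threshold for n, p in zip(max_new, max_prev))


def online_softmax_columns(score_blocks, value_blocks, threshold=0.0):
    """Softmax-weighted sum per column, processed one score block at a time.

    score_blocks: list of blocks, each a list of rows, each row a list of
                  per-column scores (base-2 units, assumed finite).
    value_blocks: list of blocks, each a list with one scalar value per row.
    """
    n_cols = len(score_blocks[0][0])
    max_vec = [float("-inf")] * n_cols   # running max per column
    norm_vec = [0.0] * n_cols            # running normaliser per column
    out = [0.0] * n_cols                 # unnormalised output per column

    for scores, vals in zip(score_blocks, value_blocks):
        # col_max_acc: max(src_accum[j], max over rows of src[*, j]).
        max_new = [max(max_vec[j], max(row[j] for row in scores))
                   for j in range(n_cols)]

        # Lazy rescale (assumed behaviour): skip when no column grew by
        # more than `threshold`, keeping the old max as the reference.
        if not all_below(max_new, max_vec, threshold):
            scale = [2.0 ** (max_vec[j] - max_new[j]) for j in range(n_cols)]
            out = [o * s for o, s in zip(out, scale)]             # mul_col_inplace
            norm_vec = [n * s for n, s in zip(norm_vec, scale)]
            max_vec = max_new

        # sub_col_inplace followed by exp2.
        probs = [[2.0 ** (row[j] - max_vec[j]) for j in range(n_cols)]
                 for row in scores]

        # col_sum_acc: running norm plus this block's column sums.
        norm_vec = [norm_vec[j] + sum(p[j] for p in probs)
                    for j in range(n_cols)]
        out = [out[j] + sum(p[j] * v for p, v in zip(probs, vals))
               for j in range(n_cols)]

    # div_col: final o_reg / norm_vec.
    return [o / n for o, n in zip(out, norm_vec)]
```

Because the same reference max is used for both the numerator and the normaliser, skipping the rescale does not change the final quotient. It only allows intermediate exp2 values up to 2 ** threshold instead of 1.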