Mojo module
hk_mha_softmax
Softmax primitives for HKMhaPrefill: cross-lane reductions and column-broadcast ops over col_l rt_32x32 register tiles.
Each column of an rt_32x32 is held redundantly across two half-warps
(lanes [0, 32) and [32, 64)). Per-column reductions combine
in-lane via SIMD.reduce_* and then across half-warps via
permlane_swap[32], a single-cycle DPP-style swap. Using stdlib's
lane_group_reduce here would lower to ds_bpermute_b32 (LDS-routed),
so we go through permlane_swap directly.
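
The two-stage reduction described above can be modelled in plain Python. This is only a sketch, not the kernel code: the lane-to-row mapping (lane l holds 16 rows of column l % 32), the helper names, and the use of `lane ^ 32` to model the half-wave swap are all assumptions made for illustration, and the real col_l register layout may differ.

```python
# Plain-Python model of a two-stage per-column max on a 64-lane wave.
# ASSUMPTION: lane l holds 16 rows of column l % 32, so lanes l and l + 32
# together cover all 32 rows of that column. Names here are hypothetical.

WAVE = 64
HALF = 32
ROWS_PER_LANE = 16


def scatter_to_lanes(tile):
    """Distribute a 32x32 tile (list of rows) across 64 simulated lanes."""
    lanes = []
    for lane in range(WAVE):
        col = lane % HALF
        row0 = (lane // HALF) * ROWS_PER_LANE
        lanes.append([tile[r][col] for r in range(row0, row0 + ROWS_PER_LANE)])
    return lanes


def swap_halves(values):
    """Model of the half-wave swap: lane l reads the value of lane l ^ 32."""
    return [values[lane ^ HALF] for lane in range(WAVE)]


def col_max_model(tile):
    lanes = scatter_to_lanes(tile)
    partial = [max(regs) for regs in lanes]   # stage 1: in-lane reduce
    other = swap_halves(partial)              # stage 2: one cross-half swap
    return [max(a, b) for a, b in zip(partial, other)]


# Deterministic test tile and a direct per-column maximum for comparison.
tile = [[(7 * r + 13 * c) % 101 for c in range(32)] for r in range(32)]
per_lane = col_max_model(tile)
expected = [max(tile[r][c] for r in range(32)) for c in range(32)]
```

After the single swap, both lanes that share a column hold the same full-column maximum, which is why no further cross-lane traffic is needed in this model.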
Functions
- col_max: Computes col_accum[j] = max(src[*, j, *]).
- col_max_acc: Running-max for online softmax: col_accum[j] = max(src_accum[j], max(src[*, j, *])).
- col_sum_acc: Running-norm for online softmax: col_accum[j] = src_accum[j] + sum(src[*, j, *]).
- div_col: dst[gr, gc] = src[gr, gc] / vec[gc] (final o_reg / norm_vec).
- mul_col_inplace: dst[gr, gc] *= vec[gc] in place; the lazy-rescale step of online softmax (o_reg *= exp2(max_prev - max_new) when the running max grows).
- rv_all_below: Returns wave-uniform True iff every lane satisfies max_new - max_prev <= threshold. Wave AND-reduce via a 64-bit ballot compared against the full-exec mask (attend_ker always runs with all 64 lanes active). Used by the lazy-rescale skip path.
- sub_col_inplace: dst[gr, gc] -= vec[gc] in place; computes att_block - max_vec before exp2 in online softmax.
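
To show how these primitives might fit together, here is a plain-Python reference model of a base-2 online softmax with a lazy-rescale skip. It is a sketch under assumptions, not the kernel: tiles are nested lists, the function bodies mirror only the formulas listed above, the driver `online_softmax_attend` and the checker `reference` are hypothetical names, and keeping the old running max when the skip path is taken is an assumed reading of "lazy rescale".

```python
NEG_INF = float("-inf")

def col_max_acc(src_accum, src):
    return [max(src_accum[j], max(row[j] for row in src)) for j in range(len(src_accum))]

def col_sum_acc(src_accum, src):
    return [src_accum[j] + sum(row[j] for row in src) for j in range(len(src_accum))]

def sub_col_inplace(dst, vec):
    for row in dst:
        for j in range(len(vec)):
            row[j] -= vec[j]

def mul_col_inplace(dst, vec):
    for row in dst:
        for j in range(len(vec)):
            row[j] *= vec[j]

def div_col(src, vec):
    return [[row[j] / vec[j] for j in range(len(vec))] for row in src]

def rv_all_below(max_new, max_prev, threshold):
    # Ballot-style AND-reduce: one bit per modelled lane, compared to a full mask.
    full, ballot = (1 << len(max_new)) - 1, 0
    for lane, (n, p) in enumerate(zip(max_new, max_prev)):
        if n - p <= threshold:
            ballot |= 1 << lane
    return ballot == full

def online_softmax_attend(score_blocks, value_blocks, threshold=0.0):
    n_cols, d = len(score_blocks[0][0]), len(value_blocks[0][0])
    max_vec, norm_vec = [NEG_INF] * n_cols, [0.0] * n_cols
    o_reg = [[0.0] * n_cols for _ in range(d)]
    for att_block, vals in zip(score_blocks, value_blocks):
        att = [row[:] for row in att_block]
        max_new = col_max_acc(max_vec, att)
        if not rv_all_below(max_new, max_vec, threshold):
            # Running max grew enough: rescale the accumulators to the new max.
            scale = [2.0 ** (p - n) for p, n in zip(max_vec, max_new)]
            mul_col_inplace(o_reg, scale)
            norm_vec = [x * s for x, s in zip(norm_vec, scale)]
            max_vec = max_new
        sub_col_inplace(att, max_vec)                 # att_block - max_vec
        probs = [[2.0 ** x for x in row] for row in att]
        norm_vec = col_sum_acc(norm_vec, probs)
        for k in range(d):                            # o_reg += vals^T @ probs
            for j in range(n_cols):
                o_reg[k][j] += sum(vals[r][k] * probs[r][j] for r in range(len(vals)))
    return div_col(o_reg, norm_vec)                   # final o_reg / norm_vec

def reference(score_blocks, value_blocks):
    s = [row for blk in score_blocks for row in blk]
    v = [row for blk in value_blocks for row in blk]
    out = []
    for k in range(len(v[0])):
        out_row = []
        for j in range(len(s[0])):
            m = max(row[j] for row in s)
            w = [2.0 ** (row[j] - m) for row in s]
            out_row.append(sum(wi * vi[k] for wi, vi in zip(w, v)) / sum(w))
        out.append(out_row)
    return out

scores = [[[((3 * r + 5 * c + 7 * b) % 11) - 5.0 for c in range(3)] for r in range(4)] for b in range(2)]
values = [[[float((r + 2 * k + b) % 5) for k in range(2)] for r in range(4)] for b in range(2)]
out_eager = online_softmax_attend(scores, values, threshold=0.0)
out_lazy = online_softmax_attend(scores, values, threshold=8.0)
ref = reference(scores, values)
```

Because softmax is invariant to the reference value subtracted before exponentiation, skipping the rescale changes only the numeric range of the intermediates in this model, not the normalized result.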