
Mojo module

hk_mha_mask

Generic mask functor application for the HK MHA att_block.

apply_mask_to_att_block is the kernel-facing entry point. It accepts any MHAMask together with the per-tile TileMaskStatus returned by mask_functor.status(...) and rewrites the per-lane FP32 fragment in place. A comptime dispatcher inside it selects one of three paths:

  • NullMask: comptime-elided (no-op). For this mask the tile status is always NO_MASK, so in practice callers never reach this path.
  • CausalMask: the 16-wide SIMD fast path (one v_cmp and one v_cndmask per stripe), generalized for start_pos so that the causal cap moves with the cache start position.
  • Anything else (SlidingWindowCausalMask, ChunkedCausalMask, MaterializedMask, fused combinations): a per-element loop over the 16 fragment slots that calls mask_functor.mask(coord, score) with the global (seq, head, q_idx, k_idx) coord. The compiler inlines the call and often re-vectorizes the loop; the production AMD path in mask_op.mojo uses the same shape.
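The three-way dispatch can be modeled in plain Python as a reference for testing. This is a sketch under stated assumptions, not the Mojo kernel. Masked scores become -inf, and the TileMaskStatus values shown are an assumed subset. A NO_MASK tile is returned untouched, and the global q_idx is taken to be start_pos + q_tile + the tile-local row. SlidingWindowCausalMask here is a hypothetical stand-in for the real functor. The exact-type check on CausalMask plays the role of the comptime dispatch, and the compare-then-select comprehensions stand in for the v_cmp/v_cndmask pair.

```python
from enum import Enum
from typing import List, Sequence, Tuple

NEG_INF = float("-inf")  # assumption: masked scores are set to -inf


class TileMaskStatus(Enum):
    # Assumed subset of statuses; the real enum may differ.
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2


class NullMask:
    def mask(self, coord: Tuple[int, int, int, int], score: float) -> float:
        return score


class CausalMask:
    def mask(self, coord: Tuple[int, int, int, int], score: float) -> float:
        _seq, _head, q_idx, k_idx = coord
        return score if k_idx <= q_idx else NEG_INF


class SlidingWindowCausalMask:
    """Hypothetical stand-in: keep keys in the half-open window (q_idx - window, q_idx]."""

    def __init__(self, window: int) -> None:
        self.window = window

    def mask(self, coord: Tuple[int, int, int, int], score: float) -> float:
        _seq, _head, q_idx, k_idx = coord
        return score if q_idx - self.window < k_idx <= q_idx else NEG_INF


def apply_mask_to_att_block(
    frag: List[float],    # one lane's 16 FP32 accumulator slots, edited in place
    rows: Sequence[int],  # tile-local query row of each slot
    cols: Sequence[int],  # tile-local key column of each slot
    mask,
    status: TileMaskStatus,
    seq: int,
    head: int,
    q_tile: int,          # first query index of the tile (cache-relative)
    k_tile: int,          # first key index of the tile
    start_pos: int,       # cache start position
) -> None:
    # Path 1: NullMask / NO_MASK tiles need no work.
    if type(mask) is NullMask or status is TileMaskStatus.NO_MASK:
        return
    # Path 2: exact-type match, mirroring a comptime type check.
    if type(mask) is CausalMask:
        caps = [start_pos + q_tile + r for r in rows]             # causal cap per slot
        keep = [k_tile + c <= cap for c, cap in zip(cols, caps)]  # the "v_cmp"
        frag[:] = [s if ok else NEG_INF for s, ok in zip(frag, keep)]  # the "v_cndmask"
        return
    # Path 3: per-element functor call with the global (seq, head, q_idx, k_idx) coord.
    for i in range(len(frag)):
        coord = (seq, head, start_pos + q_tile + rows[i], k_tile + cols[i])
        frag[i] = mask.mask(coord, frag[i])


# Toy 4x4 slot layout (not the MFMA layout), just to exercise the paths.
rows = [r for r in range(4) for _ in range(4)]
cols = [c for _ in range(4) for c in range(4)]
frag = [1.0] * 16
apply_mask_to_att_block(frag, rows, cols, CausalMask(), TileMaskStatus.PARTIAL_MASK,
                        seq=0, head=0, q_tile=0, k_tile=0, start_pos=1)
print([i for i, s in enumerate(frag) if s == NEG_INF])  # [2, 3, 7]
```

The fast path and the generic loop evaluate the same predicate for a causal mask. A subclass of CausalMask fails the exact-type check and falls through to path 3, so it should produce an identical fragment. That makes it a cheap cross-check of the fast path.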

The per-element row-within-tile mapping follows the accumulator fragment geometry of v_mfma_f32_32x32x16_bf16; see MhaMmaOp.ACC_ROW_OFFSETS_32x32.
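The row mapping can be illustrated with a small Python model. It assumes the standard CDNA wave64 layout for a 32x32 MFMA accumulator. In that layout, lane l holds column l % 32, and accumulator slot i holds row 8*(i//4) + 4*(l//32) + i%4. It also assumes that fragment rows are query positions and columns are key positions. acc_row_offsets_32x32 is a hypothetical reconstruction of the lane-independent part of that formula, not the actual MhaMmaOp.ACC_ROW_OFFSETS_32x32.

```python
from typing import List, Tuple

# Hypothetical reconstruction (not the real MhaMmaOp constant): lane-independent
# row offset of each of the 16 accumulator slots, assuming the standard CDNA
# 32x32 MFMA output layout -> [0, 1, 2, 3, 8, 9, 10, 11, 16, ..., 27].
acc_row_offsets_32x32: List[int] = [8 * (i // 4) + (i % 4) for i in range(16)]


def slot_row_col(lane: int, slot: int) -> Tuple[int, int]:
    """Tile-local (row, col) of accumulator slot `slot` held by `lane` in a wave64."""
    if not (0 <= lane < 64 and 0 <= slot < 16):
        raise ValueError("lane must be in [0, 64) and slot in [0, 16)")
    row = acc_row_offsets_32x32[slot] + 4 * (lane // 32)  # upper half-wave shifts by 4 rows
    col = lane % 32                                       # each lane owns one column
    return row, col


def lane_coords(lane: int) -> Tuple[List[int], List[int]]:
    """Rows and columns of all 16 slots of one lane, e.g. to drive a per-element mask loop."""
    pairs = [slot_row_col(lane, s) for s in range(16)]
    return [r for r, _ in pairs], [c for _, c in pairs]


rows, cols = lane_coords(33)
print(rows[:8], cols[0])
```

Under these assumptions, the 64 lanes times 16 slots cover each cell of the 32x32 tile exactly once. A per-element mask only needs the tile origin plus these offsets to rebuild the global coordinate.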

Functions