Mojo module
hk_mha_mask
Generic mask functor application for the HK MHA att_block.
apply_mask_to_att_block is the kernel-facing entry. It accepts any
MHAMask and the per-tile TileMaskStatus returned by
mask_functor.status(...), and rewrites the per-lane FP32 fragment
in place. The comptime dispatcher inside picks one of three paths:
- NullMask: comptime-elided (no-op). The status will be NO_MASK, so callers don't even reach this path in practice.
- CausalMask: the 16-wide SIMD fast path (one v_cmp + one v_cndmask per stripe), generalized for start_pos so the causal cap moves with the cache start position.
- Anything else (SlidingWindowCausalMask, ChunkedCausalMask, MaterializedMask, fused combinations): per-element loop over the 16 fragment slots, calling mask_functor.mask(coord, score) with the global (seq, head, q_idx, k_idx) coord. Compiler inlines and (often) re-vectorizes; the mask_op.mojo production AMD path uses the same shape.
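The three paths can be modeled on the host with a scalar reference, which is useful for checking kernel output. The sketch below is plain Python, not the Mojo kernel code. The names apply_mask_reference, causal_mask and sliding_window_causal_mask are hypothetical, and so are the string status values other than NO_MASK. It also assumes that masked scores become negative infinity and that start_pos is added to the query index before the functor is called.

```python
NEG_INF = float("-inf")  # assumption: masked scores are replaced with -inf


def causal_mask(coord, score):
    """Hypothetical causal functor: a key is visible when k_idx <= q_idx."""
    _seq, _head, q_idx, k_idx = coord
    return score if k_idx <= q_idx else NEG_INF


def sliding_window_causal_mask(window):
    """Hypothetical sliding-window functor; the window rule is an assumption."""
    def mask(coord, score):
        _seq, _head, q_idx, k_idx = coord
        visible = k_idx <= q_idx and (q_idx - k_idx) < window
        return score if visible else NEG_INF
    return mask


def apply_mask_reference(tile, mask, status, seq, head, q_base, k_base, start_pos=0):
    """Scalar model of the dispatch: rewrite `tile` (list of rows) in place.

    mask=None stands in for NullMask; status == "NO_MASK" skips the tile.
    Every other mask goes through the per-element functor call.
    """
    if mask is None or status == "NO_MASK":
        return tile  # nothing to rewrite
    for r, row in enumerate(tile):
        for c in range(len(row)):
            # Global coordinate; the query index is shifted by start_pos (assumption).
            coord = (seq, head, q_base + r + start_pos, k_base + c)
            row[c] = mask(coord, row[c])
    return tile


def visible_pattern(tile):
    """Render a tile as rows of 1 (kept) and 0 (masked) for easy comparison."""
    return [[0 if v == NEG_INF else 1 for v in row] for row in tile]
```

With a 4x4 tile of ones, the causal functor keeps the lower triangle, and a start_pos of 2 shifts the diagonal two columns to the right. That shift is the "causal cap moves with the cache start position" behavior described above.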
Per-element row-within-tile mapping comes from the
v_mfma_f32_32x32x16_bf16 accumulator fragment geometry; see
MhaMmaOp.ACC_ROW_OFFSETS_32x32.
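To illustrate what such a row mapping looks like, the Python sketch below assumes the 32x32 MFMA accumulator layout commonly documented for AMD matrix cores: 64 lanes with 16 FP32 slots each, where slot r of lane l holds row 8*(r // 4) + 4*(l // 32) + (r % 4) and column l % 32. The actual contents of MhaMmaOp.ACC_ROW_OFFSETS_32x32 are not shown on this page and may differ, for example if the kernel holds the tile transposed. The function names acc_coord_32x32 and row_offsets are hypothetical.

```python
def acc_coord_32x32(lane, reg):
    """Map (lane, fragment slot) to (row, col) within a 32x32 accumulator tile.

    Assumed layout: groups of 4 slots cover 4 consecutive rows, the upper
    32 lanes are shifted down by 4 rows, and each group of 4 slots skips 8 rows.
    """
    if not (0 <= lane < 64 and 0 <= reg < 16):
        raise ValueError("lane must be in [0, 64) and reg in [0, 16)")
    row = 8 * (reg // 4) + 4 * (lane // 32) + (reg % 4)
    col = lane % 32
    return row, col


def row_offsets(lane):
    """Rows touched by the 16 fragment slots of one lane, in slot order."""
    return [acc_coord_32x32(lane, reg)[0] for reg in range(16)]
```

Under this assumed layout, every lane in the lower half of the wavefront touches the same 16 rows, and the upper half touches the other 16. A per-element mask loop would add such a row offset to the tile's base index to form the global coordinate.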
Functions
- apply_mask_to_att_block: Generic mask functor application for an HK att_block tile.