
Mojo module

mha_mask

Mask functor application for the MhaPrefillV2 att_block.

MaskApplier[mask_t, Q_BLOCK_SIZE, KV_BLOCK_SIZE] (below) bundles the mask functor with the comptime block sizes and exposes a single apply() entry that comptime-dispatches over the mask type:

  • NullMask: comptime-elided no-op. The mask trait would always report NO_MASK, so the entry is statically dead.
  • CausalMask: 16-wide SIMD fast path (one v_cmp + one v_cndmask per stripe), generalized for start_pos so the causal cap moves with the cache start position. Gated on the runtime q_start_pos < kv_end_pos shortcut so fully-unmasked tiles bypass the work entirely.
  • Anything else (SlidingWindowCausalMask, ChunkedCausalMask, MaterializedMask, fused combinations): runtime mask_functor.status(...) dispatch over NO_MASK (return), FULL_MASK (fill -inf), PARTIAL (per-element loop calling mask_functor.mask(coord, score) over the 16 fragment slots).
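
The three-way status dispatch in the last bullet can be illustrated with a small reference model. This is a sketch in plain Python, not the Mojo implementation: `TileStatus`, `CausalMaskModel`, and `apply_mask` are hypothetical names, the tile is a list of lists rather than a register fragment, and the convention that a query at tile-local index `q` sits at absolute position `start_pos + q` is an assumption made for the example.

```python
import math
from enum import Enum

NEG_INF = -math.inf


class TileStatus(Enum):
    """Hypothetical stand-in for the tile status a mask functor reports."""
    NO_MASK = 0     # every score in the tile is visible
    FULL_MASK = 1   # every score in the tile is masked
    PARTIAL = 2     # the tile straddles the mask boundary


class CausalMaskModel:
    """Hypothetical causal mask whose cap is shifted by start_pos."""

    def __init__(self, start_pos=0):
        self.start_pos = start_pos

    def status(self, q_start, q_size, kv_start, kv_size):
        q_first = self.start_pos + q_start
        q_last = q_first + q_size - 1
        kv_last = kv_start + kv_size - 1
        if kv_last <= q_first:
            return TileStatus.NO_MASK    # all keys are at or before the first query
        if kv_start > q_last:
            return TileStatus.FULL_MASK  # all keys are after the last query
        return TileStatus.PARTIAL

    def mask(self, coord, score):
        q_idx, kv_idx = coord
        return score if kv_idx <= self.start_pos + q_idx else NEG_INF


def apply_mask(mask, scores, q_start, kv_start):
    """Dispatch on the tile status, then mask per element only when needed."""
    q_size, kv_size = len(scores), len(scores[0])
    status = mask.status(q_start, q_size, kv_start, kv_size)
    if status is TileStatus.NO_MASK:
        return scores                                   # early return, no work
    if status is TileStatus.FULL_MASK:
        return [[NEG_INF] * kv_size for _ in range(q_size)]
    return [
        [mask.mask((q_start + i, kv_start + j), scores[i][j]) for j in range(kv_size)]
        for i in range(q_size)
    ]


tile = [[1.0, 1.0], [1.0, 1.0]]
diagonal = apply_mask(CausalMaskModel(), tile, q_start=0, kv_start=0)
shifted = apply_mask(CausalMaskModel(start_pos=2), tile, q_start=0, kv_start=0)
```

In this model the diagonal tile is the only one that pays for the per-element loop. With `start_pos=2` the same tile coordinates fall entirely below the causal cap, so the tile is returned untouched, which is the effect the `start_pos` generalization describes.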

Per-element row-within-tile mapping comes from the v_mfma_f32_32x32x16_bf16 accumulator fragment geometry; see MhaMmaOp.ACC_ROW_OFFSETS_32x32.
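
To make the fragment geometry concrete, the following Python sketch models the accumulator layout commonly documented for AMD 32x32 MFMA instructions on a 64-lane wavefront, where each lane holds 16 values arranged as four groups of four consecutive rows. The helper names are hypothetical, and whether MhaMmaOp.ACC_ROW_OFFSETS_32x32 holds exactly these offsets (or whether the kernel works on a transposed score tile) is an assumption that should be checked against the source.

```python
def acc_rows_32x32(lane):
    """Tile rows held in the 16 accumulator slots of one lane (assumed layout)."""
    half = lane // 32  # lanes 0..31 and lanes 32..63 interleave in blocks of 4 rows
    return [8 * (slot // 4) + 4 * half + (slot % 4) for slot in range(16)]


def acc_col_32x32(lane):
    """Tile column owned by one lane (assumed layout)."""
    return lane % 32


# Every (row, col) cell of the 32x32 tile should be owned by exactly one slot.
cells = [(row, acc_col_32x32(lane)) for lane in range(64) for row in acc_rows_32x32(lane)]
```

Under this layout a per-element mask loop over the 16 slots needs only the lane id and a 16-entry row-offset table to recover each score's coordinate within the tile.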

Structs