IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MaskApplier

struct MaskApplier[mask_t: MHAMask, //, Q_BLOCK_SIZE: Int, KV_BLOCK_SIZE: Int]

Mask functor + comptime block-size bundle for MhaPrefillV2.

Owns the runtime mask functor and exposes a single apply() entry point that comptime-dispatches over mask_t. The dispatch consolidates what was previously a two-level hop (MhaPrefillV2._maybe_apply_mask → apply_mask_to_att_block). After @always_inline, both layers fold into the same set of branches, so the consolidated form is codegen-identical while leaving one less level of indirection to follow.

The struct is lightweight: its only field is mask_functor. The comptime block sizes are parameters (not fields), so the dispatch arithmetic folds to literal Ints at instantiation.

Parameters​

  • ​mask_t (MHAMask): The mask functor type (any MHAMask).
  • ​Q_BLOCK_SIZE (Int): Q rows per tile (MhaConfigV2.q_block_size).
  • ​KV_BLOCK_SIZE (Int): K rows per tile (MhaConfigV2.kv_block).

Fields​

  • mask_functor (mask_t): The runtime mask functor that apply() dispatches on.

Implemented traits​

AnyType, ImplicitlyDeletable

Methods​

__init__​

def __init__(out self, mask_functor: mask_t)

Bundle the mask functor. Comptime block sizes come from the struct's parameters.

apply​

def apply[T_dst: DType, layout: TensorLayout, //](self, mut att_block: TileTensor[T_dst, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], q_tile_idx: Int, k_tile_idx: Int, start_pos: Int, head_idx: UInt32, batch_idx: UInt32, lane: Int, num_keys: Int = -1)

Mask an att_block tile, comptime-dispatching on mask_t.

  • NullMask → key-bound fast path (_apply_kbound_mask_fast) on tiles whose K range extends past num_keys (the partial last tile and the phantom even-parity tile); fully valid tiles are a branch-cheap no-op.
  • CausalMask → runtime q_start_pos < kv_end_pos shortcut (most non-trailing tiles bypass mask work entirely), then the 16-wide SIMD fast path with start_pos shift.
  • Any other MHAMask (SlidingWindowCausalMask, ChunkedCausalMask, MaterializedMask, fused combinations) → runtime mask_functor.status(...) dispatch over three cases: NO_MASK (return immediately), FULL_MASK (fill with -inf, so the subsequent exp2 zeros every entry), and PARTIAL (per-element mask_functor.mask(coord, score) loop over the 16 fragment slots). The runtime status() call and enum branching add a few SGPR ops per tile, a cost that is acceptable for masks without a comptime-specialized fast path. Production callers: Gemma-3 (sliding window), Gemma-4 (chunked).
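The generic branch in the last bullet can be modeled in plain Python. This is a minimal scalar sketch, not the Mojo implementation: TileStatus, SlidingWindowModel, and apply_generic are hypothetical names, the status() and mask() signatures are simplified stand-ins for the real MHAMask methods, and the sliding-window rule (key k is visible to query q iff q - window < k <= q) is an assumed convention.

```python
import math
from enum import Enum


class TileStatus(Enum):
    # Hypothetical stand-in for the three statuses named in the text.
    NO_MASK = 0
    FULL_MASK = 1
    PARTIAL = 2


class SlidingWindowModel:
    """Assumed rule: key k is visible to query q iff q - window < k <= q."""

    def __init__(self, window):
        self.window = window

    def status(self, q_start, q_size, k_start, k_size):
        q_lo, q_hi = q_start, q_start + q_size - 1
        k_lo, k_hi = k_start, k_start + k_size - 1
        # Every (q, k) pair in the tile is visible: nothing to do.
        if k_hi <= q_lo and k_lo > q_hi - self.window:
            return TileStatus.NO_MASK
        # No pair is visible: all keys are in the future or outside the window.
        if k_lo > q_hi or k_hi <= q_lo - self.window:
            return TileStatus.FULL_MASK
        return TileStatus.PARTIAL

    def mask(self, q, k, score):
        visible = q - self.window < k <= q
        return score if visible else -math.inf


def apply_generic(mask, tile, q_start, k_start):
    """Mutate `tile` (a list of rows of scores) in place; return the status."""
    q_size, k_size = len(tile), len(tile[0])
    st = mask.status(q_start, q_size, k_start, k_size)
    if st is TileStatus.NO_MASK:
        return st  # early return, tile untouched
    for i in range(q_size):
        for j in range(k_size):
            if st is TileStatus.FULL_MASK:
                tile[i][j] = -math.inf  # a later 2**x maps this to 0.0
            else:
                # PARTIAL: ask the mask about each element individually.
                tile[i][j] = mask.mask(q_start + i, k_start + j, tile[i][j])
    return st


tile = [[1.0, 2.0], [3.0, 4.0]]
status = apply_generic(SlidingWindowModel(window=4), tile, q_start=4, k_start=4)
```

In this model, only the PARTIAL case pays for per-element work. The other two cases are decided from the tile bounds alone, which is why the source describes the status() call as a small per-tile cost.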

Args:

  • ​att_block (TileTensor[T_dst, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL]): Attention block tile (mutated in place).
  • ​q_tile_idx (Int): Absolute Q tile index (this warp's first row / Q_BLOCK_SIZE; usually block_tile_idx * NUM_WARPS + w_id).
  • ​k_tile_idx (Int): Absolute K tile index.
  • ​start_pos (Int): KV-cache start position. Shifts the Q absolute position by start_pos to account for cache reuse.
  • ​head_idx (UInt32): Q head index (block_idx.x in the kernel).
  • ​batch_idx (UInt32): Batch index (block_idx.z in the kernel).
  • ​lane (Int): lane_id() cast to Int.
  • num_keys (Int): Runtime K/V sequence length for the NullMask kbound; -1 (the default) disables it. MLA callers whose masks don't need the bound omit it.
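The index arithmetic behind these arguments can be sketched in plain Python. This is an illustrative model under stated assumptions, not the kernel code: the helper names are hypothetical, a tile's first row is taken to be its tile index times the block size, kv_end_pos is assumed to be the exclusive end of the K tile, and the key bound is assumed to mask every key position at or beyond num_keys.

```python
import math


def q_tile_index(block_tile_idx, num_warps, w_id):
    # The usual derivation named in the q_tile_idx description.
    return block_tile_idx * num_warps + w_id


def causal_needs_mask(q_tile_idx, k_tile_idx, start_pos,
                      q_block_size, kv_block_size):
    # Absolute position of the tile's first Q row, shifted by start_pos.
    q_start_pos = q_tile_idx * q_block_size + start_pos
    # Assumed: exclusive end position of the K tile.
    kv_end_pos = (k_tile_idx + 1) * kv_block_size
    # If the first Q row is already at or past the tile's last key,
    # every key in the tile is visible and mask work can be skipped.
    return q_start_pos < kv_end_pos


def apply_kbound(row, k_tile_idx, kv_block_size, num_keys=-1):
    # num_keys == -1 (the default) disables the key bound entirely.
    if num_keys < 0:
        return list(row)
    k_start = k_tile_idx * kv_block_size
    return [s if k_start + j < num_keys else -math.inf
            for j, s in enumerate(row)]
```

For example, with 16-row tiles and start_pos = 32, Q tile 0 begins at absolute position 32, so K tile 1 (keys 16 through 31) needs no causal masking; with start_pos = 0 the same tile pair does.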