Mojo struct
MaskApplier
struct MaskApplier[mask_t: MHAMask, //, Q_BLOCK_SIZE: Int, KV_BLOCK_SIZE: Int]
Mask functor + comptime block-size bundle for MhaPrefillV2.
Owns the runtime mask functor and exposes a single apply() entry
that comptime-dispatches over mask_t. The dispatch consolidates
what was previously a two-level hop (MhaPrefillV2._maybe_apply_mask
→ apply_mask_to_att_block); after @always_inline both layers
fold into the same set of branches, so the consolidated form is
codegen-identical while being one level less to follow.
The struct is light: a single mask_functor field. The comptime
block sizes are parameters (not fields) so the dispatch arithmetic
folds to literal Ints at instantiation.
Parameters
- mask_t (MHAMask): The mask functor type (any MHAMask).
- Q_BLOCK_SIZE (Int): Q rows per tile (MhaConfigV2.q_block_size).
- KV_BLOCK_SIZE (Int): K rows per tile (MhaConfigV2.kv_block).
Fields
- mask_functor (mask_t): The runtime mask functor owned by the struct.
Implemented traits
Methods
__init__
def __init__(out self, mask_functor: mask_t)
Bundle the mask functor. Comptime block sizes come from the struct's parameters.
apply
def apply[T_dst: DType, layout: TensorLayout, //](self, mut att_block: TileTensor[T_dst, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], q_tile_idx: Int, k_tile_idx: Int, start_pos: Int, head_idx: UInt32, batch_idx: UInt32, lane: Int, num_keys: Int = -1)
Mask an att_block tile, comptime-dispatching on mask_t.
- NullMask: key-bound fast path (_apply_kbound_mask_fast) on tiles whose K range extends past num_keys (partial last tile + phantom even-parity tile); fully-valid tiles are a branch-cheap no-op.
- CausalMask: runtime q_start_pos < kv_end_pos shortcut (most non-trailing tiles bypass mask work entirely), then the 16-wide SIMD fast path with start_pos shift.
- Any other MHAMask (SlidingWindowCausalMask, ChunkedCausalMask, MaterializedMask, fused combinations): runtime mask_functor.status(...) dispatch over NO_MASK (return), FULL_MASK (fill with -inf; the subsequent exp2 zeros every entry), and PARTIAL (per-element mask_functor.mask(coord, score) loop over the 16 fragment slots). The runtime status() call + enum branching adds a few SGPR ops per tile; the per-tile cost is acceptable for masks without a comptime-specialized fast path. Production callers: Gemma-3 (sliding window), Gemma-4 (chunked).
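The generic branch above can be modeled outside the kernel. The following is a minimal Python sketch, not the Mojo implementation: the TileMaskStatus enum, the SlidingWindowCausal class, and apply_generic are hypothetical stand-ins, and the tile is a plain list of lists rather than a TileTensor fragment. It only illustrates the three-way NO_MASK / FULL_MASK / PARTIAL dispatch and why a tile filled with -inf is zeroed by a later exp2.

```python
import math
from enum import Enum


class TileMaskStatus(Enum):
    # Hypothetical names mirroring the three outcomes described above.
    NO_MASK = 0
    FULL_MASK = 1
    PARTIAL = 2


class SlidingWindowCausal:
    """Hypothetical mask functor: key k is visible to query q
    when q - window < k <= q."""

    def __init__(self, window):
        self.window = window

    def mask(self, q, k, score):
        # Per-element mask: keep the score if visible, else -inf.
        return score if q - self.window < k <= q else -math.inf

    def status(self, q_start, q_size, k_start, k_size):
        q_end = q_start + q_size - 1  # inclusive
        k_end = k_start + k_size - 1  # inclusive
        # Every key is in the future, or every key fell out of the window.
        if k_start > q_end or k_end <= q_start - self.window:
            return TileMaskStatus.FULL_MASK
        # Every key is in the past and still inside the window.
        if k_end <= q_start and k_start > q_end - self.window:
            return TileMaskStatus.NO_MASK
        return TileMaskStatus.PARTIAL


def apply_generic(tile, mask, q_start, k_start):
    """Mask `tile` in place and return the status that was dispatched on."""
    q_size, k_size = len(tile), len(tile[0])
    status = mask.status(q_start, q_size, k_start, k_size)
    if status is TileMaskStatus.NO_MASK:
        return status  # nothing to do
    for i in range(q_size):
        for j in range(k_size):
            if status is TileMaskStatus.FULL_MASK:
                tile[i][j] = -math.inf
            else:
                tile[i][j] = mask.mask(q_start + i, k_start + j, tile[i][j])
    return status


def make_tile():
    return [[1.0, 1.0], [1.0, 1.0]]


swa = SlidingWindowCausal(window=4)

visible_tile = make_tile()
visible_status = apply_generic(visible_tile, swa, q_start=4, k_start=2)

future_tile = make_tile()
future_status = apply_generic(future_tile, swa, q_start=4, k_start=6)

diagonal_tile = make_tile()
diagonal_status = apply_generic(diagonal_tile, swa, q_start=4, k_start=4)
```

In this model the diagonal tile is the only one that pays the per-element loop; the other two resolve from the status check alone, which is the trade the runtime status() call buys.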
Args:
- att_block (TileTensor[T_dst, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL]): Attention block tile (mutated in place).
- q_tile_idx (Int): Absolute Q tile index (this warp's first row / Q_BLOCK_SIZE; usually block_tile_idx * NUM_WARPS + w_id).
- k_tile_idx (Int): Absolute K tile index.
- start_pos (Int): KV-cache start position. Shifts the Q absolute position by start_pos to account for cache reuse.
- head_idx (UInt32): Q head index (block_idx.x in the kernel).
- batch_idx (UInt32): Batch index (block_idx.z in the kernel).
- lane (Int): lane_id() cast to Int.
- num_keys (Int): Runtime K/V sequence length for the NullMask kbound; -1 (the default) disables it. MLA callers whose masks don't need the bound omit it.
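The index arithmetic behind these arguments can be sketched in a few lines. This is a Python model under stated assumptions, not kernel code: the helper names are hypothetical, kv_end_pos is assumed to be the exclusive end of the K tile, and the key-bound check assumes a tile needs masking when its exclusive end exceeds num_keys.

```python
def warp_q_tile_idx(block_tile_idx, num_warps, w_id):
    # The usual derivation quoted for q_tile_idx.
    return block_tile_idx * num_warps + w_id


def causal_tile_needs_mask(q_tile_idx, k_tile_idx, start_pos,
                           q_block_size, kv_block_size):
    # Absolute position of the tile's first Q row, shifted by start_pos.
    q_start_pos = q_tile_idx * q_block_size + start_pos
    # Assumption: kv_end_pos is one past the tile's last K position.
    kv_end_pos = (k_tile_idx + 1) * kv_block_size
    return q_start_pos < kv_end_pos


def kbound_tile_needs_mask(k_tile_idx, kv_block_size, num_keys=-1):
    # num_keys = -1 (the default) disables the key bound.
    if num_keys < 0:
        return False
    return (k_tile_idx + 1) * kv_block_size > num_keys


q_idx = warp_q_tile_idx(block_tile_idx=2, num_warps=4, w_id=1)
skips_early_keys = not causal_tile_needs_mask(q_idx, 0, 0, 16, 16)
masks_diagonal = causal_tile_needs_mask(q_idx, q_idx, 0, 16, 16)
```

With a nonzero start_pos the same Q tile sits further along the sequence, so more K tiles take the shortcut; for example, causal_tile_needs_mask(0, 3, 100, 16, 16) is False in this model while the same call with start_pos=0 is True.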