IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo trait

MHAMask

The MHAMask trait describes masks for MHA kernels, such as the causal mask.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

apply_log2e_after_mask

comptime apply_log2e_after_mask

Does the mask require log2e to be applied after the mask, or can it be fused with the scaling?
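As a rough illustration, assume the kernel evaluates softmax with exp2 and folds log2(e) into the softmax scale (an assumption; this page does not say so). The Python sketch below shows why the ordering matters: folding is exact when the mask only replaces scores with -inf, but a mask that adds a finite bias to the score needs log2e applied after masking. All function names are hypothetical.

```python
import math

LOG2E = math.log2(math.e)  # log2(e), so 2**(x * LOG2E) == e**x

def exp_fused(score, scale, bias):
    # log2e is folded into the scale; the mask's bias is added afterwards
    # and is therefore NOT multiplied by log2e.
    return 2.0 ** (score * (scale * LOG2E) + bias)

def exp_after_mask(score, scale, bias):
    # The mask's bias is added first, then the sum is converted to base 2.
    return 2.0 ** ((score * scale + bias) * LOG2E)

def reference(score, scale, bias):
    # The value softmax actually needs: e**(scaled score + bias).
    return math.exp(score * scale + bias)
```

With a bias of 0 or -inf the two orderings agree; with a finite bias such as 0.5, only exp_after_mask matches the reference.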

check_mask_during_decoding

comptime check_mask_during_decoding

Should we check the mask during decoding, or should we assume that it does not return FULL_MASK?

device_type

comptime device_type

Indicates the type used on accelerator devices.

mask_out_of_bound

comptime mask_out_of_bound

mask_safe_out_of_bounds

comptime mask_safe_out_of_bounds

Is the mask safe to read out of bounds?

Required methods

__init__

def __init__(out self: _Self, *, copy: _Self)

Create a new instance of the value by copying an existing one.

Args:

  • copy (_Self): The value to copy.

Returns:

_Self

def __init__(out self: _Self, *, deinit move: _Self)

Create a new instance of the value by moving the value of another.

Args:

  • move (_Self): The value to move.

Returns:

_Self

mask

def mask[dtype: DType, width: SIMDSize, //, *, element_type: DType = DType.uint32](self: _Self, coord: IndexList[Int(4), element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]

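Returns the masked score vector at the given coordinates.

Args:

  • coord (IndexList): The position (seq_id, head, q_idx, k_idx) of score_vec in the score matrix.
  • score_vec (SIMD[dtype, width]): The scores located at coord in the score matrix.

The functor can also capture a mask tensor and add it to the score (for example, for Replit).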

Returns:

SIMD[dtype, width]
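A minimal Python model of the mask method follows. It assumes that the width lanes of score_vec hold the scores for consecutive keys k_idx, k_idx + 1, ... at query q_idx (the lane layout is an assumption). causal_mask and additive_mask are hypothetical stand-ins for a causal mask and for a functor that captures a mask tensor; they are not the Mojo API.

```python
NEG_INF = float("-inf")

def causal_mask(coord, score_vec):
    # Hypothetical causal mask: key k is visible to query q iff k <= q.
    seq_id, head, q_idx, k_idx = coord
    return [s if k_idx + i <= q_idx else NEG_INF
            for i, s in enumerate(score_vec)]

def additive_mask(bias_tensor):
    # Hypothetical functor that captures a bias tensor indexed
    # [head][q_idx][k_idx] and adds it to the scores.
    def apply(coord, score_vec):
        seq_id, head, q_idx, k_idx = coord
        row = bias_tensor[head][q_idx]
        return [s + row[k_idx + i] for i, s in enumerate(score_vec)]
    return apply
```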

status

def status[*, element_type: DType = DType.uint32](self: _Self, seq_id: UInt32, tile_offset: IndexList[Int(2), element_type=element_type], tile_size: IndexList[Int(2), element_type=element_type]) -> TileMaskStatus

Given a tile's index range, return its masking status.

seq_id identifies the sequence/batch this tile belongs to and is used by masks (e.g., CausalPaddingMask) whose status depends on per-sequence state. Implementations that don't need it should ignore it; the unused argument is removed by dead-code elimination (DCE).

Returns:

TileMaskStatus
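For intuition, here is a Python sketch of status for a causal mask. It assumes tile_offset is (first row, first column), tile_size is (rows, columns), and a key is visible iff its column is at most the query row. The enum is a hypothetical mirror of three of the TileMaskStatus values named on this page, with arbitrary numeric values.

```python
from enum import Enum

class TileMaskStatus(Enum):
    # Hypothetical Python mirror; the values are arbitrary.
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2

def causal_status(seq_id, tile_offset, tile_size):
    # seq_id is unused, like masks whose status is sequence-independent.
    r0, c0 = tile_offset
    bm, bn = tile_size
    last_row = r0 + bm - 1
    last_col = c0 + bn - 1
    if c0 > last_row:
        return TileMaskStatus.FULL_MASK   # every key is after every query
    if last_col <= r0:
        return TileMaskStatus.NO_MASK     # every key is at or before every query
    return TileMaskStatus.PARTIAL_MASK    # the tile straddles the diagonal
```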

start_column

def start_column[BM: Int, BN: Int, page_size: Int](self: _Self, seq_id: UInt32, row: UInt32) -> UInt32

Returns the first column for which this mask does not return TileMaskStatus.FULL_MASK.

This value may not be a multiple of BN. In that case, iterating from start_column using masked_set_ends will not necessarily produce the same set or number of iterations as iterating from column 0 and checking status to skip tiles. The return value of total_iters should be less than or equal to the number of non-skipped iterations.

The practical consequence is that all warp group specializations within a kernel that loop over columns must agree on one scheme: either they all loop over every column tile and check status to skip, or they all loop using masked_set_ends.

Returns:

UInt32
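To see why the two iteration schemes in the start_column description can disagree, consider a hypothetical mask for which every column below some lo is fully masked in a given row block, so start_column returns lo. The Python sketch below (a model, not kernel code) uses lo = 40, BN = 32, and 136 columns.

```python
def iters_from_start_column(start_col, bn, num_cols):
    # Mask-driven iteration: tiles begin at start_column and advance by BN.
    return list(range(start_col, num_cols, bn))

def iters_with_status_skip(lo, bn, num_cols):
    # Status-driven iteration: tiles begin at multiples of BN, and a tile is
    # skipped when all of its columns are below lo (FULL_MASK).
    return [c for c in range(0, num_cols, bn) if c + bn - 1 >= lo]

print(iters_from_start_column(40, 32, 136))  # [40, 72, 104]
print(iters_with_status_skip(40, 32, 136))   # [32, 64, 96, 128]
```

The schemes differ in both tile boundaries and iteration count, which is why every warp group specialization must use the same one.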

start_column_alignment

static def start_column_alignment[BM: Int, BN: Int, page_size: Int]() -> Int

The largest power of 2, capped at BN, that divides every base_kv_row = start_column + k*BN produced by BN-stride mask-driven iteration. Callers pass this directly as base_alignment to PagedKVCache.populate, which uses it to pick the largest legal SIMD chunk for its LUT vector load.

Implementations must return a value that already divides BN (equivalently, the value must equal gcd(natural_alignment, BN)). For an implementation whose natural start_column alignment is a power of 2 less than or equal to BN, this is automatic. An implementation whose natural alignment doesn't divide BN must wrap its return in gcd(..., BN) itself.

Returns:

Int
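A brute-force check of this definition, written in Python for a hypothetical mask whose start columns are multiples of a known natural alignment (the sample start columns are made up), shows why the gcd(..., BN) wrap is needed when that alignment does not divide BN.

```python
from math import gcd

def largest_pow2_dividing(n, cap):
    # Largest power of 2 that is <= cap and divides n (0 is divisible by any).
    a = 1
    while a * 2 <= cap and n % (a * 2) == 0:
        a *= 2
    return a

def brute_force_alignment(start_columns, bn, num_steps=8):
    # Minimum over every sampled base_kv_row = start_column + k*BN.
    result = bn
    for s in start_columns:
        for k in range(num_steps):
            result = min(result, largest_pow2_dividing(s + k * bn, bn))
    return result

# Natural alignment 16 divides BN = 64, so no gcd wrapping is needed.
print(brute_force_alignment([0, 16, 48], 64), gcd(16, 64))   # 16 16
# Natural alignment 64 does not divide BN = 96; the answer is gcd(64, 96).
print(brute_force_alignment([0, 64, 128], 96), gcd(64, 96))  # 32 32
```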

total_iters

def total_iters[BM: Int, BN: Int, page_size: Int](self: _Self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32

The total number of column iterations for which this mask returns either TileMaskStatus.NO_MASK or TileMaskStatus.PARTIAL_MASK. This is intended for warp specializations that do not need to use kv_row.

Returns:

UInt32
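For a causal mask, assuming row is the first query row of a BM-row block and ignoring page_size (both assumptions), total_iters has a closed form. The Python sketch checks it against a brute-force count of BN-wide tiles that are not FULL_MASK; the function names are hypothetical.

```python
def causal_total_iters(bm, bn, row, num_cols):
    # Columns 0 .. row + bm - 1 are visible to at least one row in the block.
    visible_cols = min(row + bm, num_cols)
    return (visible_cols + bn - 1) // bn  # ceil(visible_cols / bn)

def causal_total_iters_brute(bm, bn, row, num_cols):
    # Count tiles whose first column is visible to the block's last row.
    return sum(1 for c0 in range(0, num_cols, bn) if c0 <= row + bm - 1)

print(causal_total_iters(64, 64, 128, 1024))  # 3
```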

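count_nonfull_sets

static def count_nonfull_sets(BM: Int, BN: Int) -> Int

The number of sets of iterations that are either partially masked or not masked. This is the length of the tuples returned by nonfull_sets, mask_strategies, and masked_set_ends.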

Returns:

Int

masked_set_ends

def masked_set_ends[BM: Int, BN: Int, page_size: Int](self: _Self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, _Self.count_nonfull_sets(BM, BN)]

For each set of iterations in nonfull_sets, indicates the end index of that set. The end is exclusive: the last index in the set is end - 1. The final value in masked_set_ends does not necessarily equal total_iters when there are UNKNOWN_MASK sets. In that case, masked_set_ends must be used together with tile skipping to obtain the correct kv_row value at each iteration.

Returns:

StaticTuple[UInt32, _Self.count_nonfull_sets(BM, BN)]
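As an illustration only, a hypothetical causal mask could use two sets, nonfull_sets = (NO_MASK, PARTIAL_MASK), so count_nonfull_sets would be 2. This is an assumption, not a statement about the library's own causal mask. The Python sketch computes the set ends under the same assumptions as a plain causal mask (row is the first row of the block, page_size ignored) and walks them with the exclusive-end convention described above.

```python
def causal_masked_set_ends(bm, bn, row, num_cols):
    # Set 1 ends at total_iters: every tile that is not FULL_MASK.
    total = (min(row + bm, num_cols) + bn - 1) // bn
    # Set 0 (NO_MASK): tiles whose last column c0 + bn - 1 <= row,
    # i.e. fully visible even to the block's first row.
    no_mask = min((row + 1) // bn, total)
    return (no_mask, total)

def iterate_sets(set_ends):
    # Yield (set index, iteration index); each set covers [start, end).
    start = 0
    for set_idx, end in enumerate(set_ends):
        for it in range(start, end):
            yield set_idx, it
        start = end

ends = causal_masked_set_ends(64, 64, 128, 1024)
print(ends)                      # (2, 3)
print(list(iterate_sets(ends)))  # [(0, 0), (0, 1), (1, 2)]
print(ends[-1])                  # 3, the last_masked_set_end
```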

last_masked_set_end

def last_masked_set_end[BM: Int, BN: Int, page_size: Int](self: _Self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32

Equivalent to masked_set_ends[BM,BN,page_size](seq_id, row, num_cols)[-1].

Returns:

UInt32

nonfull_sets

static def nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, _Self.count_nonfull_sets(BM, BN)]

For each set of iterations that are either partially masked or not masked, this indicates the mask status. UNKNOWN_MASK here indicates that the status should be checked at runtime. It is semantically equivalent to PARTIAL_MASK, but with the optimization hint that it's worth checking on each iteration at runtime for FULL_MASK (in which case the tile can be skipped) or NO_MASK (in which case the inner loop can be unswitched to avoid masking).

Returns:

StaticTuple[TileMaskStatus, _Self.count_nonfull_sets(BM, BN)]
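The following Python sketch models how a consumer might act on nonfull_sets, resolving UNKNOWN_MASK per iteration as described above. The enum, plan_iterations, and the runtime_status callback are hypothetical; they only illustrate the skip and unswitch decisions.

```python
from enum import Enum

class TileMaskStatus(Enum):
    # Hypothetical Python mirror of the statuses named on this page.
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2
    UNKNOWN_MASK = 3

def plan_iterations(nonfull_sets, set_ends, runtime_status):
    # Return (iteration, action) pairs, resolving UNKNOWN_MASK at runtime.
    plan = []
    start = 0
    for status, end in zip(nonfull_sets, set_ends):
        for it in range(start, end):
            s = runtime_status(it) if status is TileMaskStatus.UNKNOWN_MASK else status
            if s is TileMaskStatus.FULL_MASK:
                continue  # skip the tile entirely
            # NO_MASK can take an unswitched path with no masking.
            plan.append((it, "unmasked" if s is TileMaskStatus.NO_MASK else "masked"))
        start = end
    return plan

runtime = {2: TileMaskStatus.FULL_MASK, 3: TileMaskStatus.NO_MASK,
           4: TileMaskStatus.PARTIAL_MASK}
print(plan_iterations((TileMaskStatus.NO_MASK, TileMaskStatus.UNKNOWN_MASK),
                      (2, 5), runtime.get))
# [(0, 'unmasked'), (1, 'unmasked'), (3, 'unmasked'), (4, 'masked')]
```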

mask_strategies

static def mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, _Self.count_nonfull_sets(BM, BN)]

For each set of iterations that are either partially masked or not masked, this indicates the MaskStrategy to use.

Returns:

StaticTuple[MaskStrategy, _Self.count_nonfull_sets(BM, BN)]

mask_bits

def mask_bits(self: _Self, seq_id: UInt32, score_row: Int32, col_start: Int32, num_keys: Int32) -> UInt32

Returns a 32-bit visibility bitmask for a column batch.

Bit i (0..31) is 1 iff the column col_start + i is visible at (seq_id, score_row) for this mask. Called by SM100 apply_mask once per 32-col batch when the mask's mask_strategies returns MaskStrategy.BITMASK. Masks that don't use BITMASK should return 0xFFFF_FFFF (no constraint); the result is unused in that case.

Args:

  • seq_id (UInt32): Per-sequence batch index (e.g. for CausalPaddingMask).
  • score_row (Int32): Global query row (the q index in the attention score).
  • col_start (Int32): Global key index of bit 0 in this batch.
  • num_keys (Int32): Kernel cache length (upper bound on visible keys).

Returns:

UInt32
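A Python sketch of mask_bits for a hypothetical causal mask, assuming column col is visible iff col <= score_row and col < num_keys (seq_id is accepted but unused):

```python
def causal_mask_bits(seq_id, score_row, col_start, num_keys):
    # Bit i is 1 iff column col_start + i is visible under a causal mask.
    bits = 0
    for i in range(32):
        col = col_start + i
        if col <= score_row and col < num_keys:
            bits |= 1 << i
    return bits  # always fits in 32 bits because i < 32

print(hex(causal_mask_bits(0, 3, 0, 100)))    # 0xf
print(hex(causal_mask_bits(0, 100, 0, 100)))  # 0xffffffff
```

The num_keys bound clips the last batch: with score_row = 100, col_start = 96, and num_keys = 100, only columns 96 through 99 are visible, so the result is 0xf.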

sliding_window_size

static def sliding_window_size() -> Int

Returns the sliding window lower-bound offset, or 0 if unbounded.

For SlidingWindowCausalMask, returns window_size. MLA decode kernels read this to recover the window size in places where the struct's parametric window_size is not accessible through the trait surface.

Returns:

Int
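To illustrate the 0-means-unbounded convention, here is a Python sketch of a sliding-window causal visibility rule. The exact window boundary (a key is visible iff it is at or before the query and fewer than window_size positions behind it) is an assumption, not taken from SlidingWindowCausalMask.

```python
def sliding_window_visible(row, col, window_size):
    # Hypothetical rule: causal, plus a lower bound of window_size positions.
    if col > row:
        return False  # future keys are always masked
    # window_size == 0 means unbounded, i.e. a plain causal mask.
    return window_size == 0 or row - col < window_size
```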

name

static def name() -> String

Returns:

String

get_type_name

static def get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

Provided methods

copy

def copy(self: _Self) -> _Self

Explicitly construct a copy of self, a convenience method for Self(copy=self) when the type is inconvenient to write out.

Overriding this method is not allowed.

Returns:

_Self: A copy of this value.