Mojo trait
MHAMask
The MHAMask trait describes masks for MHA kernels, such as the causal mask.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
apply_log2e_after_mask
comptime apply_log2e_after_mask
Does the mask require log2e to be applied after the mask, or can it be fused with the scaling?
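One way to read this flag, sketched in Python rather than Mojo: kernels commonly compute exp(x) as exp2(x * log2e). When the mask only keeps or drops scores, log2e can be folded into the scale factor once. When the mask adds a bias to the score, log2e has to be applied to the masked value instead. The function names below are hypothetical, and this interpretation is an assumption, not the kernel's actual code.

```python
import math

LOG2E = math.log2(math.e)

def weights_fused(scores, scale):
    # log2e is folded into the scale once, before any masking.
    fused = scale * LOG2E
    return [2.0 ** (s * fused) for s in scores]

def weights_log2e_after_mask(scores, scale, bias):
    # An additive mask changes the scaled score, so log2e is applied last.
    return [2.0 ** ((s * scale + b) * LOG2E) for s, b in zip(scores, bias)]
```

Both paths are meant to reproduce exp of the (masked) scaled score; they differ only in where the multiplication by log2e happens.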
check_mask_during_decoding
comptime check_mask_during_decoding
Should we check the mask during decoding, or should we assume that it does not return FULL_MASK?
device_type
comptime device_type
Indicates the type used on accelerator devices.
mask_out_of_bound
comptime mask_out_of_bound
mask_safe_out_of_bounds
comptime mask_safe_out_of_bounds
Is the mask safe to read out of bounds?
Required methods
__init__
def __init__(out self: _Self, *, copy: _Self)
Create a new instance of the value by copying an existing one.
Args:
- copy (_Self): The value to copy.
Returns:
_Self
def __init__(out self: _Self, *, deinit move: _Self)
Create a new instance of the value by moving the value of another.
Args:
- move (_Self): The value to move.
Returns:
_Self
mask
def mask[dtype: DType, width: SIMDSize, //, *, element_type: DType = DType.uint32](self: _Self, coord: IndexList[Int(4), element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]
Returns the mask vector at the given coordinates.
Arguments:
- coord is (seq_id, head, q_idx, k_idx).
- score_vec is the vector at coord of the score matrix.
The functor could capture a mask tensor and add it to the score, e.g. Replit.
Returns:
SIMD[dtype, width]
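As a rough illustration of what a mask call does, here is a minimal Python model of a causal mask. It assumes that lane i of score_vec corresponds to key index k_idx + i and that masked lanes are set to negative infinity; both are assumptions made for this sketch, and causal_mask is a hypothetical name.

```python
NEG_INF = float("-inf")

def causal_mask(coord, score_vec):
    # coord follows the documented layout: (seq_id, head, q_idx, k_idx).
    seq_id, head, q_idx, k_idx = coord
    # Assumption: lane i holds the score for key k_idx + i.
    return [s if k_idx + i <= q_idx else NEG_INF
            for i, s in enumerate(score_vec)]
```

For example, with q_idx = 2 and k_idx = 1, a four-lane vector keeps the lanes for keys 1 and 2 and masks the lanes for keys 3 and 4.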
status
def status[*, element_type: DType = DType.uint32](self: _Self, seq_id: UInt32, tile_offset: IndexList[Int(2), element_type=element_type], tile_size: IndexList[Int(2), element_type=element_type]) -> TileMaskStatus
Given a tile's index range, return its masking status.
seq_id identifies the sequence/batch this tile belongs to and is used by masks (e.g., CausalPaddingMask) whose status depends on per-sequence state. Implementations that don't need it should ignore it; the unused argument will be removed by dead-code elimination (DCE).
Returns:
TileMaskStatus
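The tile classification can be modeled in Python for a plain causal mask. The sketch assumes tile_offset and tile_size are ordered (query, key) and uses a stand-in enum with arbitrary values; neither detail is confirmed by this page.

```python
from enum import Enum

class TileMaskStatus(Enum):
    # Stand-in enum; the numeric values are arbitrary.
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2

def causal_status(tile_offset, tile_size):
    # Assumption: both pairs are ordered (query, key).
    q0, k0 = tile_offset
    qn, kn = tile_size
    q_last = q0 + qn - 1
    k_last = k0 + kn - 1
    if k0 > q_last:
        # Even the smallest key lies after the largest query row.
        return TileMaskStatus.FULL_MASK
    if k_last <= q0:
        # The largest key is visible to the smallest query row.
        return TileMaskStatus.NO_MASK
    return TileMaskStatus.PARTIAL_MASK
```

A kernel can skip FULL_MASK tiles entirely and avoid per-element masking on NO_MASK tiles, which is why the status is computed per tile.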
start_column
def start_column[BM: Int, BN: Int, page_size: Int](self: _Self, seq_id: UInt32, row: UInt32) -> UInt32
Returns the first column for which this mask does not return TileMaskStatus.FULL_MASK. This value may not be a multiple of BN, in which case iterating with start_column and masked_set_ends will not necessarily produce the same set or number of iterations as iterating from 0 and checking status to skip tiles. The return value of total_iters should be less than or equal to the number of non-skipped iterations. The practical consequence is that all warp group specializations within a kernel that loop over columns must agree: either they all loop over all columns and check status to skip, or they all loop using masked_set_ends.
Returns:
UInt32
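The mismatch between the two iteration schemes can be seen with a small Python model of a hypothetical sliding-window causal mask, where row r sees columns r - window + 1 through r. The helper names and the window rule are assumptions for illustration only.

```python
def visible_range(row, BM, window, num_cols):
    # Union of visible columns over the BM query rows starting at `row`.
    first = max(0, row - window + 1)
    last = min(row + BM - 1, num_cols - 1)
    return first, last

def skip_based_starts(row, BM, BN, window, num_cols):
    # Iterate from column 0 in steps of BN and skip fully masked tiles.
    first, last = visible_range(row, BM, window, num_cols)
    starts = []
    for col in range(0, num_cols, BN):
        fully_masked = (col + BN - 1 < first) or (col > last)
        if not fully_masked:
            starts.append(col)
    return starts

def start_column_based_starts(row, BM, BN, window, num_cols):
    # Iterate from the first visible column in steps of BN.
    first, last = visible_range(row, BM, window, num_cols)
    return list(range(first, last + 1, BN))
```

With row = 200, BM = BN = 64, window = 100, and num_cols = 264, the skip-based loop visits tiles starting at 64, 128, 192, and 256, while the start_column-based loop visits 101, 165, and 229. The two loops touch different tiles and run a different number of iterations, so every warp group in a kernel has to pick the same scheme.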
start_column_alignment
static def start_column_alignment[BM: Int, BN: Int, page_size: Int]() -> Int
The largest power of 2, capped at BN, that divides every base_kv_row = start_column + k*BN produced by BN-stride mask-driven iteration. Callers pass this directly as base_alignment to PagedKVCache.populate, which uses it to pick the largest legal SIMD chunk for its LUT vector load.
Implementations must return a value that already divides BN (equivalently, the value must equal gcd(natural_alignment, BN)). For an implementation whose natural start_column alignment is a power of 2 less than or equal to BN, this is automatic. An implementation whose natural alignment doesn't divide BN must wrap its return value in gcd(..., BN) itself.
Returns:
Int
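The gcd rule can be checked numerically. The Python sketch below assumes start_column is always a multiple of some natural alignment and that BN is a power of 2; the function names are hypothetical.

```python
import math

def start_column_alignment(natural_alignment, BN):
    # Wrapping in gcd guarantees that the result divides BN.
    return math.gcd(natural_alignment, BN)

def divides_all_base_rows(natural_alignment, BN, alignment, n=8):
    # Check base_kv_row = start_column + k*BN for a grid of start columns
    # (multiples of natural_alignment) and iteration counts k.
    return all((m * natural_alignment + k * BN) % alignment == 0
               for m in range(n) for k in range(n))
```

For instance, a natural alignment of 48 with BN = 64 gives 16: every value 48*m + 64*k is divisible by 16, but not always by 32.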
total_iters
def total_iters[BM: Int, BN: Int, page_size: Int](self: _Self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32
The total number of column iterations for which this mask returns either TileMaskStatus.NO_MASK or TileMaskStatus.PARTIAL_MASK. This is to be used by warp specializations that do not need to use kv_row.
Returns:
UInt32
count_nonfull_sets
static def count_nonfull_sets(BM: Int, BN: Int) -> Int
The number of sets (blocks) of iterations that are either partially masked or not masked.
Returns:
Int
masked_set_ends
def masked_set_ends[BM: Int, BN: Int, page_size: Int](self: _Self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, _Self.count_nonfull_sets(::SIMD[::DType(int), ::SIMDSize(1)],::SIMD[::DType(int), ::SIMDSize(1)])(BM, BN)]
For each set of iterations in nonfull_sets, indicates the end index of that set; the end is exclusive, so the last index in the set is end - 1. Note that the final masked_set_ends value may not equal total_iters if there are UNKNOWN_MASKs. In that case, masked_set_ends with tile-skipping must be used to have the correct kv_row values at each iteration.
Returns:
StaticTuple[UInt32, _Self.count_nonfull_sets(BM, BN)]
last_masked_set_end
def last_masked_set_end[BM: Int, BN: Int, page_size: Int](self: _Self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32
Equivalent to masked_set_ends[BM, BN, page_size](seq_id, row, num_cols)[-1].
Returns:
UInt32
nonfull_sets
static def nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, _Self.count_nonfull_sets(::SIMD[::DType(int), ::SIMDSize(1)],::SIMD[::DType(int), ::SIMDSize(1)])(BM, BN)]
For each set of iterations that are either partially masked or not masked, this indicates the mask status. UNKNOWN_MASK here is an indicator meaning that the status should be checked at runtime. It is semantically equivalent to PARTIAL_MASK, but carries the optimization hint that it is worth checking on each iteration at runtime for FULL_MASK (in which case the tile can be skipped) or NO_MASK (in which case the kernel can unswitch and avoid masking in an inner loop).
Returns:
StaticTuple[TileMaskStatus, _Self.count_nonfull_sets(BM, BN)]
mask_strategies
static def mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, _Self.count_nonfull_sets(::SIMD[::DType(int), ::SIMDSize(1)],::SIMD[::DType(int), ::SIMDSize(1)])(BM, BN)]
For each set of iterations that are either partially masked or not masked, this indicates the MaskStrategy to use.
Returns:
StaticTuple[MaskStrategy, _Self.count_nonfull_sets(BM, BN)]
mask_bits
def mask_bits(self: _Self, seq_id: UInt32, score_row: Int32, col_start: Int32, num_keys: Int32) -> UInt32
Returns a 32-bit visibility bitmask for a column batch.
Bit i (0..31) is 1 iff the column col_start + i is visible at
(seq_id, score_row) for this mask. Called by SM100 apply_mask
once per 32-col batch when the mask's mask_strategies returns
MaskStrategy.BITMASK. Masks that don't use BITMASK should return
0xFFFF_FFFF (no constraint); the result is unused in that case.
Args:
- seq_id (UInt32): Per-sequence batch index (e.g. for CausalPaddingMask).
- score_row (Int32): Global query row (the q index in the attention score).
- col_start (Int32): Global key index of bit 0 in this batch.
- num_keys (Int32): Kernel cache length (upper bound on visible keys).
Returns:
UInt32
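The bit layout can be modeled in Python for a causal mask. The sketch assumes a column is visible when it is at or before score_row and below num_keys; that visibility rule and the function name are assumptions, not the SM100 implementation.

```python
def causal_mask_bits(score_row, col_start, num_keys):
    bits = 0
    for i in range(32):
        col = col_start + i
        # Bit i is set iff column col_start + i is visible.
        if 0 <= col <= score_row and col < num_keys:
            bits |= 1 << i
    return bits
```

A batch that lies entirely inside the visible region yields 0xFFFF_FFFF, and a batch entirely past the query row yields 0.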
sliding_window_size
static def sliding_window_size() -> Int
Returns the sliding window lower-bound offset, or 0 if unbounded.
For SlidingWindowCausalMask, returns window_size. MLA decode
kernels read this to recover the window size in places where the
struct's parametric window_size is not accessible through the
trait surface.
Returns:
Int
name
get_type_name
static def get_type_name() -> String
Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.
Returns:
String: The host type's name.
Provided methods
copy
def copy(self: _Self) -> _Self
Explicitly construct a copy of self, a convenience method for Self(copy=self) when the type is inconvenient to write out.
Overriding this method is not allowed.
Returns:
_Self: A copy of this value.