Mojo trait
MHAMask
The MHAMask trait describes masks for MHA kernels, such as the causal mask.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
apply_log2e_after_mask
comptime apply_log2e_after_mask
Does the mask require log2e to be applied after the mask, or can it be fused with the scaling?
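The following reading of this flag is an assumption, not a description of the implementation. Kernels often compute exp(x) as exp2(x * log2e). log2e can be folded into the softmax scale only if the mask does not add a finite bias after scaling. The Python sketch below checks this numerically. `masked_exp_fused`, `masked_exp_after`, and the additive `bias` are hypothetical names.

```python
import math

LOG2E = 1.0 / math.log(2.0)  # log2(e)


def masked_exp_fused(score: float, scale: float, bias: float) -> float:
    # log2(e) folded into the scale, so the mask's bias lands in the log2 domain unscaled
    return 2.0 ** (score * (scale * LOG2E) + bias)


def masked_exp_after(score: float, scale: float, bias: float) -> float:
    # log2(e) applied after the mask has added its bias to the scaled score
    return 2.0 ** ((score * scale + bias) * LOG2E)


score, scale = 1.5, 0.125
reference = math.exp(score * scale + 0.5)  # exp(scaled score + bias)

print(math.isclose(masked_exp_after(score, scale, 0.5), reference))  # True
print(math.isclose(masked_exp_fused(score, scale, 0.5), reference))  # False

# A mask that only writes -inf yields 0.0 either way, so fusing is safe for it.
print(masked_exp_fused(score, scale, -math.inf), masked_exp_after(score, scale, -math.inf))  # 0.0 0.0
```

Under this assumption, a mask that only sets scores to -inf could report that log2e can be fused with the scaling. A mask that adds finite values would need log2e applied after the mask.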
check_mask_during_decoding
comptime check_mask_during_decoding
Should we check the mask during decoding, or should we assume that it does not return FULL_MASK?
device_type
comptime device_type
Indicates the type used on accelerator devices.
mask_out_of_bound
comptime mask_out_of_bound
mask_safe_out_of_bounds
comptime mask_safe_out_of_bounds
Is the mask safe to read out of bounds?
Required methods
__init__
__init__(out self: _Self, *, copy: _Self)
Create a new instance of the value by copying an existing one.
Args:
- copy (_Self): The value to copy.
Returns:
_Self
__init__(out self: _Self, *, deinit take: _Self)
Create a new instance of the value by moving the value of another.
Args:
- take (_Self): The value to move.
Returns:
_Self
mask
mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self: _Self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]
Return mask vector at given coordinates.
Arguments:
- coord: The (seq_id, head, q_idx, k_idx) coordinate.
- score_vec: The scores at coord in the score matrix.
The functor could also capture a mask tensor and add it to the score (e.g., Replit).
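The Python model below sketches what a causal implementation of mask might compute, plus a functor that captures a mask tensor. It assumes score_vec holds the scores for `width` consecutive key positions starting at k_idx in query row q_idx. It also assumes masked entries are set to -inf. `causal_mask`, `additive_mask`, and `bias` are hypothetical names, not part of the Mojo API.

```python
from typing import Callable, List, Sequence, Tuple

Coord = Tuple[int, int, int, int]  # (seq_id, head, q_idx, k_idx)
MASK_VALUE = float("-inf")         # assumed fill value for masked scores


def causal_mask(coord: Coord, score_vec: List[float]) -> List[float]:
    # Keep scores whose key position is at or before the query position.
    _seq_id, _head, q_idx, k_idx = coord
    return [s if k_idx + i <= q_idx else MASK_VALUE for i, s in enumerate(score_vec)]


def additive_mask(bias: Sequence[Sequence[float]]) -> Callable[[Coord, List[float]], List[float]]:
    # A functor that captures a mask tensor bias[q][k] and adds it to the scores.
    def apply(coord: Coord, score_vec: List[float]) -> List[float]:
        _seq_id, _head, q_idx, k_idx = coord
        return [s + bias[q_idx][k_idx + i] for i, s in enumerate(score_vec)]
    return apply


print(causal_mask((0, 0, 2, 0), [0.5, 1.0, 1.5, 2.0]))  # [0.5, 1.0, 1.5, -inf]
add = additive_mask([[0.0, -1.0], [-2.0, 0.0]])
print(add((0, 0, 1, 0), [1.0, 1.0]))  # [-1.0, 1.0]
```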
Returns:
status
status[*, element_type: DType = DType.uint32](self: _Self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus
Given a tile's index range, return its masking status.
Returns:
TileMaskStatus
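Below is a Python sketch of status for a causal mask. It assumes tile_offset is (first query row, first key column) and tile_size is (rows, columns). The TileMaskStatus enum is a stand-in that models only the three statuses used here, and `causal_status` is hypothetical.

```python
from enum import Enum
from typing import Tuple


class TileMaskStatus(Enum):  # stand-in for the Mojo type
    NO_MASK = "no_mask"
    PARTIAL_MASK = "partial_mask"
    FULL_MASK = "full_mask"


def causal_status(tile_offset: Tuple[int, int], tile_size: Tuple[int, int]) -> TileMaskStatus:
    q0, k0 = tile_offset
    rows, cols = tile_size
    q_last = q0 + rows - 1
    k_last = k0 + cols - 1
    if k0 > q_last:
        return TileMaskStatus.FULL_MASK  # every key comes after every query
    if k_last <= q0:
        return TileMaskStatus.NO_MASK    # every key is at or before every query
    return TileMaskStatus.PARTIAL_MASK   # the tile straddles the diagonal


print(causal_status((0, 64), (64, 64)))  # TileMaskStatus.FULL_MASK
print(causal_status((64, 0), (64, 64)))  # TileMaskStatus.NO_MASK
print(causal_status((0, 0), (64, 64)))   # TileMaskStatus.PARTIAL_MASK
```

Only the extreme corners of the tile need to be compared. This is why a tile's status can be decided without inspecting every element.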
start_column
start_column[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32) -> UInt32
Returns the first column for which this mask does not return TileMaskStatus.FULL_MASK. This may not be a multiple of BN, in which case iterating using start_column and masked_set_ends will not necessarily produce the same set or number of iterations as iterating from 0 and checking status to skip. The return value of total_iters should be less than or equal to the number of non-skipped iterations. The practical consequence is that all warp group specializations within a kernel that loop over columns need to be in agreement. Either they all loop over all columns and check status to skip, or they loop using the masked_set_ends.
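The Python sketch below illustrates why start_column need not be a multiple of BN. It models a hypothetical sliding-window causal mask in which query q attends to keys in (q - W, q]. The window size W, the helper names, and the simplified formulas are assumptions, and page_size is ignored. In the example, W = 8, BN = 8, and the block has 8 rows starting at row 12. Iterating from start_column takes 2 iterations. Iterating BN-aligned tiles from column 0 and skipping FULL_MASK tiles visits 3, consistent with total_iters being at most the number of non-skipped iterations.

```python
def start_column(row: int, W: int) -> int:
    # First key column unmasked for some query in the block (set by query `row`).
    return max(0, row - W + 1)


def end_column(row: int, BM: int, num_cols: int) -> int:
    # One past the last key column unmasked for some query in rows [row, row + BM).
    return min(num_cols, row + BM)


def total_iters(row: int, BM: int, BN: int, num_cols: int, W: int) -> int:
    # BN-wide iterations needed when starting exactly at start_column.
    span = max(0, end_column(row, BM, num_cols) - start_column(row, W))
    return (span + BN - 1) // BN


def aligned_nonskipped(row: int, BM: int, BN: int, num_cols: int, W: int) -> int:
    # Iterate BN-aligned tiles from column 0 and skip tiles that are FULL_MASK.
    start, end = start_column(row, W), end_column(row, BM, num_cols)
    count = 0
    for k0 in range(0, num_cols, BN):
        k1 = min(k0 + BN, num_cols)
        if k1 > start and k0 < end:  # tile overlaps the unmasked columns
            count += 1
    return count


row, BM, BN, num_cols, W = 12, 8, 8, 64, 8
print(start_column(row, W))                          # 5
print(total_iters(row, BM, BN, num_cols, W))         # 2
print(aligned_nonskipped(row, BM, BN, num_cols, W))  # 3
```

Because the two schemes visit different column ranges, warp groups that mix them would disagree about which tile each iteration covers. This is the agreement requirement stated above.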
Returns:
total_iters
total_iters[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> UInt32
The total number of column iterations for which this mask returns either TileMaskStatus.NO_MASK or TileMaskStatus.PARTIAL_MASK. This is to be used by warp specializations that do not need to use kv_row.
Returns:
count_nonfull_sets
static count_nonfull_sets(BM: Int, BN: Int) -> Int
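The number of sets of iterations that are either partially masked or not masked. This is the length of the tuples returned by masked_set_ends, nonfull_sets, and mask_strategies.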
Returns:
masked_set_ends
masked_set_ends[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, _Self.count_nonfull_sets(BM, BN)]
For each set of iterations in nonfull_sets, indicates the end index belonging to that set (i.e., the last index in the set is end - 1). Note that the final entry of masked_set_ends may not equal total_iters if there are UNKNOWN_MASKs. In that case, masked_set_ends must be used together with tile skipping to obtain the correct kv_row value at each iteration.
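As an illustration, the Python sketch below assumes a causal mask whose nonfull_sets are (NO_MASK, PARTIAL_MASK). Tiles entirely below the diagonal come first, followed by the diagonal tiles. The sketch computes the corresponding masked_set_ends for a row block and shows how a kernel loop could walk the sets, with last_masked_set_end being the final entry. The helper names and the two-set layout are assumptions, not the library's causal mask, and page_size is ignored.

```python
from typing import List, Tuple

NONFULL_SETS = ("NO_MASK", "PARTIAL_MASK")  # hypothetical causal ordering


def causal_masked_set_ends(row: int, BM: int, BN: int, num_cols: int) -> Tuple[int, int]:
    end_col = min(num_cols, row + BM)
    total = (end_col + BN - 1) // BN           # tiles that are not FULL_MASK
    no_mask_end = min((row + 1) // BN, total)  # tiles whose keys are all <= row
    return (no_mask_end, total)


def visit(row: int, BM: int, BN: int, num_cols: int) -> List[Tuple[int, str]]:
    ends = causal_masked_set_ends(row, BM, BN, num_cols)
    schedule = []
    start = 0
    for status, end in zip(NONFULL_SETS, ends):
        for it in range(start, end):  # iterations start .. end - 1 belong to this set
            schedule.append((it * BN, status))  # (first key column, how to mask)
        start = end
    return schedule


ends = causal_masked_set_ends(128, 64, 64, 1024)
print(ends)      # (2, 3)
print(ends[-1])  # 3, the value last_masked_set_end would report
print(visit(128, 64, 64, 1024))
# [(0, 'NO_MASK'), (64, 'NO_MASK'), (128, 'PARTIAL_MASK')]
```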
Returns:
last_masked_set_end
last_masked_set_end[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> UInt32
Equivalent to masked_set_ends[BM,BN,page_size](row, num_cols)[-1].
Returns:
nonfull_sets
static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, _Self.count_nonfull_sets(BM, BN)]
For each set of iterations that are either partially masked or not masked, this indicates the mask status. UNKNOWN_MASK here is an indicator meaning that we should check the status at runtime. It is semantically equivalent to partial, but with the optimization hint that it's worth checking on each iteration at runtime for FULL_MASK (in which case we can skip the tile) or NO_MASK (in which case we can unswitch and avoid masking in an inner loop).
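The runtime handling of UNKNOWN_MASK described above can be sketched in Python. `run_set`, `status_fn`, and the string labels are hypothetical, and the TileMaskStatus enum is a stand-in. The sketch shows only the control flow. For an UNKNOWN_MASK set, each tile's status is checked at runtime, so FULL_MASK tiles are skipped and NO_MASK tiles take an unmasked path. Other sets use their static status directly.

```python
from enum import Enum
from typing import Callable, List, Tuple


class TileMaskStatus(Enum):  # stand-in for the Mojo type
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2
    UNKNOWN_MASK = 3


def run_set(set_status: TileMaskStatus, tiles: List[int],
            status_fn: Callable[[int], TileMaskStatus]) -> List[Tuple[int, str]]:
    log = []
    for tile in tiles:
        status = set_status
        if set_status is TileMaskStatus.UNKNOWN_MASK:
            status = status_fn(tile)  # runtime check only for UNKNOWN sets
            if status is TileMaskStatus.FULL_MASK:
                log.append((tile, "skipped"))
                continue
        if status is TileMaskStatus.NO_MASK:
            log.append((tile, "unmasked path"))
        else:
            # PARTIAL_MASK, statically or as resolved at runtime: mask per element
            log.append((tile, "masked path"))
    return log


statuses = {0: TileMaskStatus.NO_MASK, 1: TileMaskStatus.FULL_MASK, 2: TileMaskStatus.PARTIAL_MASK}
print(run_set(TileMaskStatus.UNKNOWN_MASK, [0, 1, 2], statuses.__getitem__))
# [(0, 'unmasked path'), (1, 'skipped'), (2, 'masked path')]
print(run_set(TileMaskStatus.PARTIAL_MASK, [0, 1], statuses.__getitem__))
# [(0, 'masked path'), (1, 'masked path')]
```

Treating UNKNOWN_MASK as partial would always be correct. The runtime check only pays off when some tiles in the set turn out to be fully masked or unmasked.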
Returns:
mask_strategies
static mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, _Self.count_nonfull_sets(BM, BN)]
For each set of iterations that are either partially masked or not masked, this indicates the MaskStrategy to use.
Returns:
get_type_name
static get_type_name() -> String
Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.
Returns:
String: The host type's name.
Provided methods
copy
copy(self: _Self) -> _Self
Explicitly construct a copy of self, a convenience method for Self(copy=self) when the type is inconvenient to write out.
Returns:
_Self: A copy of this value.