
Mojo trait

MHAMask

The MHAMask trait describes masks for MHA kernels, such as the causal mask.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

apply_log2e_after_mask

comptime apply_log2e_after_mask

Does the mask require log2e to be applied after the mask, or can it be fused with the scaling?
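One way to read this flag, offered as an assumption: softmax kernels often compute exp(x) as exp2(x * log2(e)). If the mask only writes -inf, log2(e) can be folded into the attention scale. If the mask adds a finite bias, the bias must also be multiplied by log2(e), so the conversion has to happen after masking. The following Python sketch models that arithmetic. It is not the Mojo API, and all function names are hypothetical.

```python
import math

LOG2E = 1.0 / math.log(2.0)  # log2(e)


def softmax_numerator_reference(score, scale, bias):
    # Reference value in natural-log units: exp(scale * score + bias).
    return math.exp(scale * score + bias)


def log2e_after_mask(score, scale, bias):
    # Scale, apply the additive mask, then convert to base 2.
    return 2.0 ** ((scale * score + bias) * LOG2E)


def log2e_fused_with_scale(score, scale, bias):
    # Fold log2e into the scale before masking. The bias is NOT multiplied
    # by log2e, so this matches the reference only when bias is 0 or -inf.
    return 2.0 ** (scale * LOG2E * score + bias)
```

With a bias of -inf, both orderings produce 0. With a finite bias such as -0.75, only `log2e_after_mask` matches the reference.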

check_mask_during_decoding

comptime check_mask_during_decoding

Should we check the mask during decoding, or should we assume that it does not return FULL_MASK?

device_type

comptime device_type

Indicates the type used on accelerator devices.

mask_out_of_bound

comptime mask_out_of_bound

mask_safe_out_of_bounds

comptime mask_safe_out_of_bounds

Is the mask safe to read out of bounds?

Required methods

__init__

__init__(out self: _Self, *, copy: _Self)

Create a new instance of the value by copying an existing one.

Args:

  • copy (_Self): The value to copy.

Returns:

_Self

__init__(out self: _Self, *, deinit take: _Self)

Create a new instance of the value by moving the value of another.

Args:

  • take (_Self): The value to move.

Returns:

_Self

mask

mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self: _Self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]

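Returns the mask vector at the given coordinates.

Args:

  • coord (IndexList[4, element_type=element_type]): The coordinates (seq_id, head, q_idx, k_idx) of score_vec.
  • score_vec (SIMD[dtype, width]): The scores located at coord in the score matrix.

The functor can capture a mask tensor and add it to the score (e.g., Replit).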

Returns:

SIMD
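The Python model below illustrates the two kinds of masks described above. It is a sketch, not the Mojo signature. Lists stand in for SIMD vectors, and we assume that lane i of score_vec corresponds to key index k_idx + i. The first function is a causal mask that writes -inf where the key is after the query. The second is a functor that captures a mask tensor and adds it to the scores. Both helper names are hypothetical.

```python
NEG_INF = float("-inf")


def causal_mask(coord, score_vec):
    # coord = (seq_id, head, q_idx, k_idx) for lane 0 of score_vec.
    _seq_id, _head, q_idx, k_idx = coord
    # Keep scores for keys at or before the query, and mask the rest.
    return [s if k_idx + i <= q_idx else NEG_INF for i, s in enumerate(score_vec)]


def additive_tensor_mask(mask_tensor):
    # mask_tensor is indexed as [seq_id][head][q_idx][k_idx] (assumed layout).
    def apply(coord, score_vec):
        seq_id, head, q_idx, k_idx = coord
        row = mask_tensor[seq_id][head][q_idx]
        # Add the captured mask value to each score lane.
        return [s + row[k_idx + i] for i, s in enumerate(score_vec)]

    return apply
```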

status

status[*, element_type: DType = DType.uint32](self: _Self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus

Returns the masking status of the tile that starts at tile_offset and spans tile_size.

Returns:

TileMaskStatus
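As an illustration, here is a Python model of how a causal mask could classify a tile. This is a sketch under several assumptions. tile_offset is taken to be (first query row, first key column) and tile_size to be (rows, columns). The enum is a stand-in whose values are not claimed to match the Mojo definition. causal_status is a hypothetical name.

```python
from enum import Enum


class TileMaskStatus(Enum):
    # Stand-in for the Mojo enum; numeric values are arbitrary here.
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2
    UNKNOWN_MASK = 3


def causal_status(tile_offset, tile_size):
    q0, k0 = tile_offset
    bm, bn = tile_size
    q_last = q0 + bm - 1
    k_last = k0 + bn - 1
    if k0 > q_last:
        # Every key in the tile comes after every query: skip the tile.
        return TileMaskStatus.FULL_MASK
    if k_last <= q0:
        # Every key is at or before every query: no masking needed.
        return TileMaskStatus.NO_MASK
    # The tile straddles the diagonal.
    return TileMaskStatus.PARTIAL_MASK
```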

start_column

start_column[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32) -> UInt32

Returns the first column for which this mask does not return TileMaskStatus.FULL_MASK. This value may not be a multiple of BN. In that case, iterating from start_column using masked_set_ends does not necessarily produce the same set or number of iterations as iterating from column 0 and checking status to skip tiles. The return value of total_iters should be less than or equal to the number of non-skipped iterations. As a practical consequence, all warp group specializations within a kernel that loop over columns must agree on one scheme: either they all loop over every column and check status to skip, or they all loop using masked_set_ends.

Returns:

UInt32
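The following hypothetical illustration shows why the two iteration schemes can disagree. It uses a Python model of a sliding-window causal mask, which is not one of the masks documented here. Query q sees key k when q - window < k <= q. The window size, the helper names, and the assumption that row is the first query row of a BM-row tile are all ours.

```python
def visible(q, k, window):
    # Sliding-window causal visibility.
    return q - window < k <= q


def window_start_column(row, window):
    # First key column visible to any query in the tile starting at `row`.
    return max(0, row - window + 1)


def tile_is_full_mask(row, bm, col, bn, num_cols, window):
    return not any(
        visible(q, k, window)
        for q in range(row, row + bm)
        for k in range(col, min(col + bn, num_cols))
    )


def tiles_from_start_column(row, bm, bn, num_cols, window):
    # Step by BN from start_column, which need not be BN-aligned.
    start = window_start_column(row, window)
    last_visible = min(num_cols - 1, row + bm - 1)
    return list(range(start, last_visible + 1, bn))


def tiles_from_zero_with_skip(row, bm, bn, num_cols, window):
    # Step by BN from column 0 and skip FULL_MASK tiles.
    return [
        c for c in range(0, num_cols, bn)
        if not tile_is_full_mask(row, bm, c, bn, num_cols, window)
    ]
```

For row=63, BM=2, BN=32, and window=1, starting at column 63 needs one tile ([63]). Iterating from 0 with skipping visits two tiles ([32, 64]). Warp groups that mix the two schemes would therefore disagree on the iteration count.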

total_iters

total_iters[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> UInt32

The total number of column iterations for which this mask returns either TileMaskStatus.NO_MASK or TileMaskStatus.PARTIAL_MASK. This is intended for warp specializations that do not need kv_row.

Returns:

UInt32
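Here is a sketch of total_iters for a plain causal mask, modeled in Python. It assumes that query index 0 aligns with key index 0 (no KV-cache offset) and that row is the first query row of a BM-row tile. Under those assumptions, the rows row through row + BM - 1 can see keys 0 through row + BM - 1, so the count has a closed form. The brute-force version checks that closed form against per-tile classification. Both function names are hypothetical.

```python
def causal_total_iters(row, num_cols, bm, bn):
    # Keys beyond the last query row are masked, and so are keys past num_cols.
    visible_cols = min(num_cols, row + bm)
    # Number of BN-wide tiles that overlap the visible key range.
    return (visible_cols + bn - 1) // bn


def brute_force_total_iters(row, num_cols, bm, bn):
    count = 0
    for col in range(0, num_cols, bn):
        # A tile is FULL_MASK iff its first key is after its last query row.
        if col <= row + bm - 1:
            count += 1
    return count
```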

count_nonfull_sets

static count_nonfull_sets(BM: Int, BN: Int) -> Int

The number of sets of iterations (blocks) in which every tile is either partially masked or not masked.

Returns:

Int

masked_set_ends

masked_set_ends[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, _Self.count_nonfull_sets(::Int,::Int)(BM, BN)]

For each set of iterations in nonfull_sets, indicates the end index of that set (i.e., the last index in the set is end - 1). Note that the final value of masked_set_ends does not necessarily equal total_iters if there are UNKNOWN_MASK sets. In that case, masked_set_ends must be combined with tile skipping to obtain the correct kv_row values at each iteration.

Returns:

StaticTuple
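The Python sketch below models how the end indices could look for a causal mask. It assumes a hypothetical nonfull_sets of (NO_MASK, PARTIAL_MASK): fully visible tiles come first, followed by the tiles that straddle the diagonal. It also assumes no KV-cache offset and that row is the first query row of the tile. The ends are cumulative iteration indices, so set s covers iterations ends[s-1] through ends[s] - 1. Both function names are hypothetical.

```python
def causal_masked_set_ends(row, num_cols, bm, bn):
    # End of the last set: every tile that is not FULL_MASK.
    total = (min(num_cols, row + bm) + bn - 1) // bn
    # Tile i is NO_MASK when its last key (i*bn + bn - 1) <= the first query row.
    no_mask_end = min((row + 1) // bn, total)
    return (no_mask_end, total)


def iterate_sets(ends):
    # Expand cumulative ends into (set_index, iteration_index) pairs.
    pairs = []
    start = 0
    for s, end in enumerate(ends):
        pairs.extend((s, i) for i in range(start, end))
        start = end
    return pairs
```

For row=128 with BM=BN=64, the ends are (2, 3). Iterations 0 and 1 belong to the NO_MASK set, and iteration 2 belongs to the PARTIAL_MASK set.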

last_masked_set_end

last_masked_set_end[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> UInt32

Equivalent to masked_set_ends[BM,BN,page_size](row, num_cols)[-1].

Returns:

UInt32

nonfull_sets

static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, _Self.count_nonfull_sets(::Int,::Int)(BM, BN)]

For each set of iterations that are either partially masked or not masked, this indicates the mask status. UNKNOWN_MASK here indicates that the status should be checked at runtime. It is semantically equivalent to PARTIAL_MASK, but carries an optimization hint: it is worth checking each iteration at runtime for FULL_MASK (in which case the tile can be skipped) or NO_MASK (in which case the loop can be unswitched to avoid masking in an inner loop).

Returns:

StaticTuple
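To show how a kernel loop might consume nonfull_sets together with masked_set_ends, here is a Python model of the control flow described above. It is a sketch with several assumptions. The enum is a stand-in, and run_tiles is a hypothetical name. runtime_status is a hypothetical callback representing a status check for tile i. The model does not show how the Mojo kernels actually dispatch.

```python
from enum import Enum


class TileMaskStatus(Enum):
    # Stand-in for the Mojo enum.
    NO_MASK = "no"
    PARTIAL_MASK = "partial"
    FULL_MASK = "full"
    UNKNOWN_MASK = "unknown"


def run_tiles(nonfull_sets, set_ends, runtime_status):
    actions = []
    start = 0
    for set_status, end in zip(nonfull_sets, set_ends):
        for i in range(start, end):
            status = set_status
            if status is TileMaskStatus.UNKNOWN_MASK:
                # Semantically PARTIAL_MASK, but checked per tile at runtime.
                status = runtime_status(i)
                if status is TileMaskStatus.FULL_MASK:
                    actions.append((i, "skip"))
                    continue
            if status is TileMaskStatus.NO_MASK:
                # Unswitched path: no per-element masking.
                actions.append((i, "unmasked"))
            else:
                actions.append((i, "masked"))
        start = end
    return actions
```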

mask_strategies

static mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, _Self.count_nonfull_sets(::Int,::Int)(BM, BN)]

For each set of iterations that are either partially masked or not masked, this indicates the MaskStrategy to use.

Returns:

StaticTuple

name

static name() -> String

Returns:

String

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int returns "Int" and DeviceBuffer[DType.float32] returns "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

Provided methods

copy

copy(self: _Self) -> _Self

Explicitly construct a copy of self, a convenience method for Self(copy=self) when the type is inconvenient to write out.

Returns:

_Self: A copy of this value.
