
Mojo trait

MHAMask

The MHAMask trait describes masks for MHA kernels, such as the causal mask.

Implemented traits

AnyType, Copyable, DevicePassable, UnknownDestructibility

Aliases

__copyinit__is_trivial

comptime __copyinit__is_trivial

A flag (often compiler generated) to indicate whether the implementation of __copyinit__ is trivial.

The implementation of __copyinit__ is considered to be trivial if:

  • The struct has a compiler-generated trivial __copyinit__ and all its fields have a trivial __copyinit__ method.

In practice, it means the value can be copied by copying the bits from one location to another without side effects.

__del__is_trivial

comptime __del__is_trivial

A flag (often compiler generated) to indicate whether the implementation of __del__ is trivial.

The implementation of __del__ is considered to be trivial if:

  • The struct has a compiler-generated trivial destructor and all its fields have a trivial __del__ method.

In practice, it means that the __del__ can be considered as no-op.

apply_log2e_after_mask

comptime apply_log2e_after_mask

Does the mask require log2e to be applied after the mask, or can it be fused with the scaling?

check_mask_during_decoding

comptime check_mask_during_decoding

Should we check the mask during decoding, or should we assume that it does not return FULL_MASK?

device_type

comptime device_type

Indicate the type being used on accelerator devices.

mask_out_of_bound

comptime mask_out_of_bound

mask_safe_out_of_bounds

comptime mask_safe_out_of_bounds

Is the mask safe to read out of bounds?

Required methods

__copyinit__

__copyinit__(out self: _Self, existing: _Self, /)

Create a new instance of the value by copying an existing one.

Args:

  • existing (_Self): The value to copy.

Returns:

_Self

mask

mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self: _Self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]

Return mask vector at given coordinates.

Arguments:

  • coord: The (seq_id, head, q_idx, k_idx) coordinates.
  • score_vec: The score vector at coord in the score matrix.

The functor could capture a mask tensor and add it to the score, e.g. Replit.

Returns:

SIMD
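As a rough illustration of what mask computes, the sketch below models a causal mask in plain Python rather than Mojo. It is not the library's implementation. It assumes that the lanes of score_vec correspond to consecutive key positions starting at k_idx, and that masked entries are replaced with negative infinity. The function name causal_mask is hypothetical.

```python
import math

def causal_mask(coord, score_vec):
    """Plain-Python model of mask() for a causal mask (illustrative only).

    coord is (seq_id, head, q_idx, k_idx). Assumption: score_vec holds the
    scores for key positions k_idx, k_idx + 1, ..., k_idx + width - 1.
    """
    _seq_id, _head, q_idx, k_idx = coord
    # Keep a score when its key position is at or before the query position;
    # otherwise replace it with -inf so that softmax gives it zero weight.
    return [
        score if k_idx + lane <= q_idx else -math.inf
        for lane, score in enumerate(score_vec)
    ]

# Query position 2 may attend to key positions 0, 1 and 2, but not 3.
masked = causal_mask((0, 0, 2, 0), [0.5, 1.0, 1.5, 2.0])
```

Here the first three lanes pass through unchanged and the fourth is masked out.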

status

status[*, element_type: DType = DType.uint32](self: _Self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus

Given a tile's index range, return its masking status.

Returns:

TileMaskStatus
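The following Python sketch models how a causal mask could classify a tile. It is illustrative only, under several assumptions: tile_offset and tile_size are taken to be ordered (row, column), that is (query, key); an element is visible when its column is less than or equal to its row; and the TileMaskStatus enum here is a hypothetical Python stand-in for the Mojo type, limited to the three statuses this sketch needs.

```python
from enum import Enum

class TileMaskStatus(Enum):
    """Hypothetical Python stand-in for the Mojo TileMaskStatus type."""
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2

def causal_status(tile_offset, tile_size):
    """Classify a tile for a causal mask (assumed (row, column) ordering)."""
    row0, col0 = tile_offset
    rows, cols = tile_size
    last_row = row0 + rows - 1
    last_col = col0 + cols - 1
    # Every key column lies after every query row: the whole tile is masked.
    if col0 > last_row:
        return TileMaskStatus.FULL_MASK
    # Every key column lies at or before every query row: nothing is masked.
    if last_col <= row0:
        return TileMaskStatus.NO_MASK
    # The tile straddles the diagonal.
    return TileMaskStatus.PARTIAL_MASK

above_diagonal = causal_status((0, 4), (4, 4))
below_diagonal = causal_status((4, 0), (4, 4))
on_diagonal = causal_status((0, 0), (4, 4))
```

A kernel can skip a tile reported as FULL_MASK and can avoid per-element masking for a tile reported as NO_MASK.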

start_column

start_column[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32) -> UInt32

Returns the first column for which this mask does not return TileMaskStatus.FULL_MASK. This may not be a multiple of BN, in which case iterating using start_column and masked_set_ends will not necessarily produce the same set or number of iterations as iterating from 0 and checking status to skip. The return value of total_iters should be less than or equal to the number of non-skipped iterations.

The practical consequence is that all warp group specializations within a kernel that loop over columns need to be in agreement: either they all loop over all columns and check status to skip, or they all loop using masked_set_ends.

Returns:

UInt32

total_iters

total_iters[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> UInt32

The total number of column iterations for which this mask returns either TileMaskStatus.NO_MASK or TileMaskStatus.PARTIAL_MASK. This is to be used by warp specializations that do not need to use kv_row.

Returns:

UInt32
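To make the count concrete, here is a small Python model of the "loop over all columns and check status to skip" strategy for a causal mask. It is a sketch under stated assumptions, not the library's implementation: row is taken to be the first query row of a tile of height BM, column tiles start at multiples of BN, and page_size is ignored. The helper names are hypothetical.

```python
def is_fully_masked(row, col, BM, BN):
    """Causal model: a tile is fully masked when its first key column
    lies after its last query row. BN is unused in this simple test."""
    return col > row + BM - 1

def count_non_skipped_iters(row, num_cols, BM, BN):
    """Count the column tiles that are not fully masked by walking every
    tile from column 0 and checking its status."""
    count = 0
    for col in range(0, num_cols, BN):
        if not is_fully_masked(row, col, BM, BN):
            count += 1
    return count

# Query rows 64..127 see the key tiles starting at columns 0 and 64 only.
iters_second_row_block = count_non_skipped_iters(64, 256, 64, 64)
# Query rows 0..63 see only the key tile starting at column 0.
iters_first_row_block = count_non_skipped_iters(0, 256, 64, 64)
```

In this model the first column is never fully masked, so the count matches what a loop driven by start_column would visit. For masks whose start column is not a multiple of BN, the two strategies can differ.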

count_nonfull_sets

static count_nonfull_sets(BM: Int, BN: Int) -> Int

The number of sets of blocks that are either partially masked or not masked. This is the length of the tuples returned by nonfull_sets and masked_set_ends.

Returns:

Int

masked_set_ends

masked_set_ends[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, _Self.count_nonfull_sets(::Int,::Int)(BM, BN)]

For each set of iterations in nonfull_sets, indicate the end index belonging to that set (i.e., the last index would be end - 1). Note that the final entry of masked_set_ends may not necessarily equal total_iters if there are UNKNOWN_MASKs. In that case, masked_set_ends with tile-skipping must be used to have the correct kv_row values at each iteration.

Returns:

StaticTuple
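One way to read the pairing of nonfull_sets and masked_set_ends is as a schedule of consecutive iteration ranges, each carrying a single mask status. The Python sketch below models that reading. It assumes the end indices are cumulative, so each set begins where the previous one ended. The status strings and the example values are hypothetical and are not output from a real mask.

```python
def iterate_sets(set_statuses, set_ends):
    """Yield (iteration_index, status) pairs for consecutive sets.

    Assumption: set i covers iteration indices from the previous set's end
    (0 for the first set) up to set_ends[i] - 1.
    """
    start = 0
    for status, end in zip(set_statuses, set_ends):
        for idx in range(start, end):
            yield idx, status
        start = end

# Hypothetical example: three unmasked iterations, then one partial one.
schedule = list(iterate_sets(["NO_MASK", "PARTIAL_MASK"], [3, 4]))
```

Because each range has a known status, a kernel following this pattern can choose its masking strategy once per set rather than once per iteration.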

last_masked_set_end

last_masked_set_end[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> UInt32

Equivalent to masked_set_ends[BM,BN,page_size](row, num_cols)[-1].

Returns:

UInt32

nonfull_sets

static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, _Self.count_nonfull_sets(::Int,::Int)(BM, BN)]

For each set of iterations that are either partially masked or not masked, this indicates the mask status. UNKNOWN_MASK here is an indicator meaning that we should check the status at runtime. It is semantically equivalent to partial, but with the optimization hint that it's worth checking on each iteration at runtime for FULL_MASK (in which case we can skip the tile) or NO_MASK (in which case we can unswitch and avoid masking in an inner loop).

Returns:

StaticTuple

name

static name() -> String

Returns:

String

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

get_device_type_name

static get_device_type_name() -> String

Gets device_type's name. For example, because DeviceBuffer's device_type is UnsafePointer, DeviceBuffer[DType.float32]'s get_device_type_name() should return something like "UnsafePointer[Scalar[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The device type's name.

Provided methods

copy

copy(self: _Self) -> _Self

Explicitly construct a copy of self.

Returns:

_Self: A copy of this value.
