
Mojo trait

MHAMask

The MHAMask trait describes masks for multi-head attention (MHA) kernels, such as the causal mask.

Implemented traits

AnyType, Copyable, DevicePassable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial

A flag (often compiler-generated) that indicates whether the implementation of __copyinit__ is trivial.

The implementation of __copyinit__ is considered to be trivial if:

  • The struct has a compiler-generated trivial __copyinit__ and all its fields have a trivial __copyinit__ method.

In practice, it means the value can be copied by copying the bits from one location to another without side effects.

__del__is_trivial

alias __del__is_trivial

A flag (often compiler-generated) that indicates whether the implementation of __del__ is trivial.

The implementation of __del__ is considered to be trivial if:

  • The struct has a compiler-generated trivial destructor and all its fields have a trivial __del__ method.

In practice, it means that __del__ can be treated as a no-op.

apply_log2e_after_mask

alias apply_log2e_after_mask

Does the mask require log2e to be applied after the mask, or can it be fused with the scaling?
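Softmax kernels commonly compute exp(x) as exp2(x * log2e), which makes it tempting to fold log2e into the attention scale. The Python sketch below (not Mojo, and not the kernel's code; all function names are hypothetical) illustrates one reading of why this flag exists. A mask that only writes -inf gives the same result either way, but a mask that adds a finite bias only reproduces exp(scale * s + bias) when log2e is applied after the mask.

```python
import math

LOG2E = math.log2(math.e)


def reference(scores, scale, bias):
    # Ground truth: exp(scale * s + bias) for each score.
    return [math.exp(s * scale + b) for s, b in zip(scores, bias)]


def numerators_fused(scores, scale, bias):
    # log2e folded into the scale before masking; the bias is NOT rescaled.
    return [2.0 ** (s * scale * LOG2E + b) for s, b in zip(scores, bias)]


def numerators_after_mask(scores, scale, bias):
    # Scale, apply the mask (add the bias), then multiply by log2e.
    return [2.0 ** ((s * scale + b) * LOG2E) for s, b in zip(scores, bias)]
```

With an all-zero or all -inf bias both variants agree with the reference; with a finite nonzero bias only the after-mask variant does.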

check_mask_during_decoding

alias check_mask_during_decoding

Should we check the mask during decoding, or should we assume that it does not return FULL_MASK?

device_type

alias device_type

Indicates the type used on accelerator devices.

mask_out_of_bound

alias mask_out_of_bound

mask_safe_out_of_bounds

alias mask_safe_out_of_bounds

Is the mask safe to read out of bounds?

Required methods

__copyinit__

__copyinit__(out self: _Self, existing: _Self, /)

Create a new instance of the value by copying an existing one.

Args:

  • existing (_Self): The value to copy.

Returns:

_Self

mask

mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self: _Self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]

Return mask vector at given coordinates.

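Args:

  • coord (IndexList[4, element_type=element_type]): The coordinates (seq_id, head, q_idx, k_idx).
  • score_vec (SIMD[dtype, width]): The values of the score matrix at coord.

The functor can capture a mask tensor and add it to the score (for example, for Replit models).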

Returns:

SIMD
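To make the mask contract concrete, here is a Python sketch (not the Mojo API) of two masks: a causal mask and an additive mask that captures a bias tensor, as the description of the functor suggests. It assumes, as an illustration only, that lane i of score_vec holds the score for key k_idx + i; the class names and the nested-list bias layout are hypothetical.

```python
from typing import List, Sequence, Tuple

NEG_INF = float("-inf")


class CausalMask:
    """Hypothetical Python model of a causal MHA mask."""

    def mask(self, coord: Tuple[int, int, int, int],
             score_vec: Sequence[float]) -> List[float]:
        _seq_id, _head, q_idx, k_idx = coord
        # Assumption: lane i holds the score for key k_idx + i.
        # Keys after the query position are masked out with -inf.
        return [s if k_idx + i <= q_idx else NEG_INF
                for i, s in enumerate(score_vec)]


class AdditiveMask:
    """Hypothetical mask that captures a bias tensor and adds it."""

    def __init__(self, bias):
        # bias[seq_id][head][q_idx][k_idx], stored as nested lists.
        self.bias = bias

    def mask(self, coord, score_vec):
        seq_id, head, q_idx, k_idx = coord
        row = self.bias[seq_id][head][q_idx]
        return [s + row[k_idx + i] for i, s in enumerate(score_vec)]
```

For example, CausalMask().mask((0, 0, 2, 0), [1.0] * 4) keeps keys 0 to 2 and masks key 3.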

status

status[*, element_type: DType = DType.uint32](self: _Self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus

Given a tile's index range, return its masking status.

Returns:

TileMaskStatus
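The following Python sketch shows how a causal mask might classify a tile without inspecting every element. It assumes a three-way status (NO_MASK, PARTIAL_MASK, FULL_MASK; only FULL_MASK is named on this page), that tile_offset is (first query row, first key column), and that tile_size is (rows, columns). The enum and function names are hypothetical. Because an element (q, k) is masked when k > q, only the tile's corners need checking.

```python
from enum import Enum
from typing import Tuple


class TileMaskStatus(Enum):
    # Hypothetical Python mirror of the status values.
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2


def causal_status(tile_offset: Tuple[int, int],
                  tile_size: Tuple[int, int]) -> TileMaskStatus:
    q0, k0 = tile_offset
    rows, cols = tile_size
    q_last = q0 + rows - 1
    k_last = k0 + cols - 1
    if k0 > q_last:
        # Even the smallest key lies after the largest query: all masked.
        return TileMaskStatus.FULL_MASK
    if k_last <= q0:
        # Even the largest key is at or before the smallest query: none masked.
        return TileMaskStatus.NO_MASK
    return TileMaskStatus.PARTIAL_MASK
```

A status like this presumably lets a kernel skip FULL_MASK tiles entirely and avoid calling mask() on NO_MASK tiles.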

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

get_device_type_name

static get_device_type_name() -> String

Gets device_type's name. For example, because DeviceBuffer's device_type is UnsafePointer, DeviceBuffer[DType.float32]'s get_device_type_name() should return something like "UnsafePointer[Scalar[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The device type's name.

Provided methods

copy

copy(self: _Self) -> _Self

Explicitly construct a copy of self.

Returns:

_Self: A copy of this value.
