
Mojo trait

MHAOperand

This trait defines the interface for arguments passed to the multi-head attention (MHA) kernel.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial

A flag (often compiler-generated) that indicates whether the implementation of __copyinit__ is trivial.

The implementation of __copyinit__ is considered to be trivial if:

  • The struct has a compiler-generated trivial __copyinit__ and all its fields have a trivial __copyinit__ method.

In practice, it means the value can be copied by copying the bits from one location to another without side effects.
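As a rough analogy only (Python with `ctypes`, not Mojo, and not how the Mojo compiler implements it), a type with a trivial `__copyinit__` behaves like a plain C struct: copying its raw bytes to a new location yields a valid, independent copy with no side effects. `Pair` and `bitwise_copy` are hypothetical names chosen for illustration.

```python
import ctypes

# A plain-old-data struct: two integer fields, no pointers or owned resources.
class Pair(ctypes.Structure):
    _fields_ = [("a", ctypes.c_int32), ("b", ctypes.c_int32)]

def bitwise_copy(src: Pair) -> Pair:
    """Copy src by copying its raw bytes into a freshly allocated Pair."""
    dst = Pair()
    ctypes.memmove(ctypes.addressof(dst), ctypes.addressof(src), ctypes.sizeof(Pair))
    return dst

original = Pair(3, 7)
clone = bitwise_copy(original)
original.a = 100  # mutating the original does not affect the bitwise copy
print(clone.a, clone.b)  # 3 7
```

A type that owns a resource (for example, a heap buffer it must free) could not be copied this way, because two copies would then share and later double-free the same resource.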

__del__is_trivial

comptime __del__is_trivial

A flag (often compiler-generated) that indicates whether the implementation of __del__ is trivial.

The implementation of __del__ is considered to be trivial if:

  • The struct has a compiler-generated trivial destructor and all its fields have a trivial __del__ method.

In practice, it means that __del__ can be treated as a no-op.

__moveinit__is_trivial

comptime __moveinit__is_trivial

A flag (often compiler-generated) that indicates whether the implementation of __moveinit__ is trivial.

The implementation of __moveinit__ is considered to be trivial if:

  • The struct has a compiler-generated __moveinit__ and all its fields have a trivial __moveinit__ method.

In practice, it means the value can be moved by moving the bits from one location to another without side effects.

device_type

comptime device_type

Indicates the type used on accelerator devices.

dtype

comptime dtype

page_size

comptime page_size

Required methods

__copyinit__

__copyinit__(out self: _Self, existing: _Self, /)

Create a new instance of the value by copying an existing one.

Args:

  • existing (_Self): The value to copy.

Returns:

_Self

__moveinit__

__moveinit__(out self: _Self, deinit existing: _Self, /)

Create a new instance of the value by moving the value of another.

Args:

  • existing (_Self): The value to move.

Returns:

_Self

block_paged_ptr

block_paged_ptr[tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = 0) -> UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]

Returns:

UnsafePointer
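The page does not describe how block_paged_ptr computes its pointer. The Python model below is a minimal sketch of the usual paged KV-cache addressing, under loudly stated assumptions: a hypothetical page table maps each (batch, logical page) to a physical block, each block is laid out row-major as `[page_size, num_heads, head_dim]`, and a tile of `tile_size` tokens is assumed to fit inside one page. It returns an element offset rather than a pointer, and `block_paged_offset` is a hypothetical name, not the library's implementation.

```python
from typing import List

def block_paged_offset(
    page_table: List[List[int]],  # hypothetical: page_table[batch][logical_page] -> physical block
    page_size: int,
    num_heads: int,
    head_dim: int,
    tile_size: int,
    batch_idx: int,
    start_tok_idx: int,
    head_idx: int,
    head_dim_idx: int = 0,
) -> int:
    """Element offset of (batch, token, head, dim) in a hypothetical paged buffer.

    Assumed layout: physical blocks stored back to back, each shaped
    [page_size, num_heads, head_dim] in row-major order.
    """
    logical_page, tok_in_page = divmod(start_tok_idx, page_size)
    # Assumption: the tile of tile_size tokens starting here must not cross
    # a page boundary, so all of its rows live in the same physical block.
    assert tok_in_page + tile_size <= page_size, "tile crosses a page boundary"
    block = page_table[batch_idx][logical_page]
    block_stride = page_size * num_heads * head_dim  # elements per physical block
    tok_stride = num_heads * head_dim                # elements per token row
    return (block * block_stride + tok_in_page * tok_stride
            + head_idx * head_dim + head_dim_idx)

# Token 6 of sequence 0 sits in logical page 1, which maps to physical block 0.
print(block_paged_offset([[2, 0], [1, 3]], 4, 2, 8, 2, 0, 6, 1, 3))  # 43
```

The indirection through the page table is what lets sequences grow without contiguous reallocation: only the table entry changes, not the data already written.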

cache_length

cache_length(self: _Self, batch_idx: Int) -> Int

Returns the length of the cache for a given batch index.

Returns:

Int

max_context_length

max_context_length(self: _Self) -> UInt32

Returns the maximum cache length across all sequences in the batch.

Returns:

UInt32
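The relationship between cache_length and max_context_length can be modelled with a toy Python class. This is a sketch assuming the operand only tracks one cache length per batch index. `ToyKVOperand` is hypothetical, and it returns Python ints where the Mojo methods return `Int` and `UInt32`.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class ToyKVOperand:
    """Hypothetical stand-in for an MHAOperand that stores per-sequence cache lengths."""
    cache_lengths: List[int]  # one entry per batch index

    def cache_length(self, batch_idx: int) -> int:
        # Cache length of a single sequence in the batch.
        return self.cache_lengths[batch_idx]

    def max_context_length(self) -> int:
        # Longest cache across all sequences in the batch (0 for an empty batch).
        return max(self.cache_lengths, default=0)

op = ToyKVOperand([5, 17, 9])
print(op.cache_length(1), op.max_context_length())  # 17 17
```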

row_idx

row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index of token start_tok_idx in sequence batch_idx when the memory is viewed as a matrix.

Returns:

UInt32
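For intuition, the Python sketch below shows one common way such a row index can be computed. It assumes a hypothetical ragged layout in which sequences are packed back to back into a `[total_tokens, row_width]` matrix, so a token's row is its sequence's starting offset plus its position within the sequence. Paged operands may compute the row differently, and `make_row_idx` is a hypothetical helper.

```python
from itertools import accumulate
from typing import List

def make_row_idx(seq_lengths: List[int]):
    """Build a hypothetical row_idx for sequences packed back to back.

    Sequence b starts at row offsets[b], the sum of the lengths of
    sequences 0..b-1 (an exclusive prefix sum).
    """
    offsets = [0, *accumulate(seq_lengths)]

    def row_idx(batch_idx: int, start_tok_idx: int) -> int:
        return offsets[batch_idx] + start_tok_idx

    return row_idx

row_idx = make_row_idx([3, 5, 2])
print(row_idx(0, 0), row_idx(1, 0), row_idx(2, 1))  # 0 3 9
```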

create_tma_tile

create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, _split_last_layout[_Self.dtype](IndexList(BN, 1, BK, Tuple()), swizzle_mode, True), _ragged_desc_layout[_Self.dtype](IndexList(BN, 1, BK, Tuple()), swizzle_mode)]

Creates a TMA (Tensor Memory Accelerator) tile for efficient GPU memory transfers. This is useful for k-major MMA operations, where extra rows don't need to be masked.

Returns:

TMATensorTile

create_ragged_tma_tile

create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BN, BK]

Creates a TMA tile for efficient GPU memory transfers. This is useful for mn-major MMA operations, where extra rows must be masked so that NaN values in them are not added to the output through the MMA reduction.

Returns:

RaggedTMA3DTile
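The reason extra rows need masking in the mn-major case can be shown with plain arithmetic. In this Python illustration (not kernel code), one output element of an MMA is a reduction over the tile's rows. A padding row read past the end of a sequence may hold NaN, and because `0 * NaN` is NaN in IEEE 754 arithmetic, a zero weight does not cancel it. Zeroing out-of-bounds rows before the reduction does. The tile contents and `valid_rows` value are made up for illustration.

```python
import math
from typing import List

def dot(a: List[float], b: List[float]) -> float:
    # One output element of an MMA: a sum of products over the reduced dimension.
    return sum(x * y for x, y in zip(a, b))

# Hypothetical tile: 3 valid rows plus 1 padding row past the end of the sequence,
# which happens to contain NaN (e.g. stale or uninitialized memory).
values = [1.0, 2.0, 3.0, math.nan]
# The weight for the padding row is zero, but 0.0 * NaN is still NaN.
weights = [0.5, 0.25, 0.25, 0.0]

unmasked = dot(weights, values)

# Masking: replace out-of-bounds rows with 0.0 before the reduction.
valid_rows = 3
masked_values = [v if i < valid_rows else 0.0 for i, v in enumerate(values)]
masked = dot(weights, masked_values)

print(unmasked, masked)  # nan 1.75
```

When the extra rows fall outside the reduced dimension instead, they only produce extra output elements that can be discarded, which is consistent with the note that k-major operations don't need this masking.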

get_type_name

static get_type_name() -> String

Gets the name of the host type (the type implementing this trait). For example, Int returns "Int", and DeviceBuffer[DType.float32] returns "DeviceBuffer[DType.float32]". The name is used in error messages when passing types to the device. TODO: This method will be retired once better kernel call error messages are available.

Returns:

String: The host type's name.

Provided methods

copy

copy(self: _Self) -> _Self

Explicitly construct a copy of self.

Returns:

_Self: A copy of this value.
