Mojo trait

MHAOperand

This trait defines the interface that arguments to the MHA (multi-head attention) kernel must implement.
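A conforming type can be consumed generically by an attention kernel. The following is a hypothetical sketch, not the real kernel: the function name, the tile size BN, and the fixed token offset are assumptions for illustration.

```mojo
# Hypothetical sketch of a kernel consuming an MHAOperand.
# `attend`, `BN`, and the token offset 0 are illustrative assumptions.
fn attend[OperandType: MHAOperand](
    k_cache: OperandType, batch_idx: UInt32, head_idx: UInt32
) -> Int:
    comptime BN = 64
    # Number of cached tokens for this batch entry.
    var kv_len = k_cache.cache_length(Int(batch_idx))
    # Pointer to a BN-token tile of K starting at token 0.
    var k_ptr = k_cache.block_paged_ptr[BN](batch_idx, 0, head_idx)
    # ... attention math would stream tiles through k_ptr ...
    _ = k_ptr
    return kv_len
```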

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

Comptime members

device_type

comptime device_type

Indicate the type being used on accelerator devices.

dtype

comptime dtype

page_size

comptime page_size

quantization_enabled

comptime quantization_enabled = False

quantization_granularity

comptime quantization_granularity

scale_dtype

comptime scale_dtype

Required methods

__init__

__init__(out self: _Self, *, copy: _Self)

Create a new instance of the value by copying an existing one.

Args:

  • copy (_Self): The value to copy.

Returns:

_Self

__init__(out self: _Self, *, deinit take: _Self)

Create a new instance of the value by moving the value of another.

Args:

  • take (_Self): The value to move.

Returns:

_Self

block_paged_ptr

block_paged_ptr[tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = 0) -> UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]

Returns:

UnsafePointer

scales_block_paged_ptr

scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.scale_dtype], ImmutAnyOrigin]

Returns:

UnsafePointer

load_scale

load_scale[width: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]

Returns:

SIMD
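When quantization_enabled is True, a kernel would typically combine loaded quantized values with the scales returned here. This is a hedged sketch only; the function name, the float32 accumulation type, and the assumption of one scale per element are illustrative, not part of the trait.

```mojo
# Hypothetical dequantization step; names and widths are illustrative.
fn dequantize_sketch[width: Int, OperandType: MHAOperand](
    cache: OperandType,
    quantized: SIMD[OperandType.dtype, width],
    batch_idx: Int,
    tok_idx: Int,
    head_idx: Int,
    head_dim_idx: Int,
) -> SIMD[DType.float32, width]:
    var scale = cache.load_scale[width](
        batch_idx, tok_idx, head_idx, head_dim_idx
    )
    # Apply the per-element scale to recover approximate full-precision values.
    return quantized.cast[DType.float32]() * scale.cast[DType.float32]()
```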

cache_length

cache_length(self: _Self, batch_idx: Int) -> Int

Returns the length of the cache for a given batch index.

Returns:

Int

max_context_length

max_context_length(self: _Self) -> UInt32

Returns the maximum context length across all batch indices.

Returns:

UInt32

num_kv_rows

num_kv_rows(self: _Self) -> Int

Returns the total number of virtual rows in the KV memory view.

For paged caches this accounts for the paging stride so that TMA descriptors can be sized to cover the entire address space.

Returns:

Int
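For a paged cache, the virtual row count normally spans the entire page pool rather than only the tokens in use, so a TMA descriptor built over it covers every reachable address. A minimal sketch, assuming a layout of total_pages fixed-size pages (both names are assumptions):

```mojo
# Hypothetical num_kv_rows for a paged layout: total_pages pages of
# page_size tokens each, so TMA descriptors can span the whole pool.
fn num_kv_rows_sketch(total_pages: Int, page_size: Int) -> Int:
    return total_pages * page_size
```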

row_idx

row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index when viewing the memory as a flat 2D matrix.

Returns:

UInt32
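For a contiguous cache this is a simple linear offset; for a paged cache it indirects through a page table. The sketch below covers the paged case under assumed layout conventions: page_table holding pages_per_seq entries per batch is an illustration, not part of the trait.

```mojo
# Hypothetical paged row_idx: look up the physical page for the token,
# then add the offset within that page. `page_table` stores
# pages_per_seq entries per batch entry (an assumed layout).
fn row_idx_sketch(
    page_table: UnsafePointer[UInt32],
    pages_per_seq: UInt32,
    page_size: UInt32,
    batch_idx: UInt32,
    start_tok_idx: UInt32,
) -> UInt32:
    var physical_page = page_table[
        Int(batch_idx * pages_per_seq + start_tok_idx // page_size)
    ]
    return physical_page * page_size + start_tok_idx % page_size
```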

create_tma_tile

create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode]()]

Creates a TMA tile for efficient GPU memory transfers. This is useful for k-major MMA operations where we don't need to mask any extra rows.

Returns:

TMATensorTile

create_scale_tma_tile

create_scale_tma_tile[BMN: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.scale_dtype, 2, Index[Int, Int](VariadicPack(1, BMN))]

Creates a TMA tile for efficient GPU memory transfers. This is useful for m-major MMA operations where we don't need to mask any extra rows.

Returns:

TMATensorTile

create_ragged_tma_tile

create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BN, BK]

Creates a TMA tile for efficient GPU memory transfers. This is useful for mn-major MMA operations where we need to mask extra rows to avoid adding NaN to the output through the MMA reduction.

Returns:

RaggedTMA3DTile

create_rope_tma_tile

create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode]()]

Creates a BF16 TMA tile for the rope portion of the per-tensor rope-aware KV cache.

Returns:

TMATensorTile

create_gather4_tma_tile

create_gather4_tma_tile[row_width: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 2, IndexList(VariadicList(4, row_width), Tuple()), IndexList(VariadicList(1, row_width), Tuple())]

Creates a 2D TMA gather4 descriptor for this operand.

The descriptor views the data as a flat 2D matrix of [num_kv_rows, row_width] and is configured for gather4 operations that load 4 non-contiguous rows per TMA instruction.

Parameters:

  • row_width (Int): Number of elements per row (innermost dimension).
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern. Defaults to SWIZZLE_NONE.

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile: A TMATensorTile with tile_shape=(4, row_width) and desc_shape=(1, row_width).
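Building the descriptor follows directly from the signature. A hedged usage sketch, assuming a row width of 64 and that descriptor creation may raise (the wrapper function is hypothetical):

```mojo
# Hypothetical use: build a gather4 descriptor over the flat
# [num_kv_rows, 64] view of the operand.
fn make_gather4[OperandType: MHAOperand](
    cache: OperandType, ctx: DeviceContext
) raises:
    var tma = cache.create_gather4_tma_tile[row_width=64](ctx)
    # Each gather4 TMA instruction through `tma` loads 4 rows
    # of 64 elements from non-contiguous row indices.
    _ = tma
```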

scales_raw_ptr

scales_raw_ptr(self: _Self) -> UnsafePointer[Float32, MutAnyOrigin]

Returns the base pointer to the quantization scales tensor.

Returns a null pointer for operands without quantization support.

Returns:

UnsafePointer

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

Provided methods

copy

copy(self: _Self) -> _Self

Explicitly constructs a copy of self. This is a convenience method for Self(copy=self) when the type is inconvenient to write out.

Returns:

_Self: A copy of this value.