Mojo trait
MHAOperand
This trait defines the interface for arguments (operands) passed to the MHA (multi-head attention) kernel.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type
Indicates the type used on accelerator devices.
dtype
comptime dtype
page_size
comptime page_size
quantization_enabled
comptime quantization_enabled = False
quantization_granularity
comptime quantization_granularity
scale_dtype
comptime scale_dtype
Required methods
__init__
__init__(out self: _Self, *, copy: _Self)
Create a new instance of the value by copying an existing one.
Args:
- copy (_Self): The value to copy.
Returns:
_Self
__init__(out self: _Self, *, deinit take: _Self)
Create a new instance of the value by moving the value of another.
Args:
- take (_Self): The value to move.
Returns:
_Self
block_paged_ptr
block_paged_ptr[tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = 0) -> UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]
scales_block_paged_ptr
scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.scale_dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[_Self.scale_dtype], ImmutAnyOrigin]
load_scale
load_scale[width: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]
Returns:
SIMD[_Self.scale_dtype, width]
cache_length
cache_length(self: _Self, batch_idx: Int) -> Int
Returns the length of the cache for a given batch index.
Returns:
Int
max_context_length
max_context_length(self: _Self) -> UInt32
Returns the maximum cache length across all entries in the batch.
Returns:
UInt32
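The relationship between the two length queries can be illustrated with a small host-side model. This is a minimal Python sketch, not the trait's implementation; ToyOperand is a hypothetical class, and it assumes max_context_length() is the largest cache_length() in the batch.

```python
class ToyOperand:
    """Hypothetical host-side stand-in for an MHA operand's length queries."""

    def __init__(self, lengths):
        # lengths[b] is the number of tokens cached for batch index b.
        self.lengths = list(lengths)

    def cache_length(self, batch_idx: int) -> int:
        return self.lengths[batch_idx]

    def max_context_length(self) -> int:
        # Assumption: the maximum is taken over every entry in the batch.
        return max(self.lengths)


op = ToyOperand([5, 12, 3])
print(op.cache_length(0))        # 5
print(op.max_context_length())   # 12
```

Under this model, a kernel might use max_context_length() as a loop bound shared by all batch entries and cache_length(b) to mask positions past the end of entry b.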
num_kv_rows
num_kv_rows(self: _Self) -> Int
Returns the total number of virtual rows in the KV memory view.
For paged caches, this accounts for the paging stride so that TMA descriptors can be sized to cover the entire address space.
Returns:
Int
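To see why the paging stride matters, consider an assumed layout in which consecutive pages of this operand are separated by more rows than they hold (for example, because other data is interleaved in the same buffer). Counting only token rows would then undersize the descriptor. The Python sketch below is hypothetical: num_pages and page_stride_rows are illustrative names, not members of this trait.

```python
def num_kv_rows(num_pages: int, page_size: int, page_stride_rows: int) -> int:
    """Rows a flat 2D view must span so every page of this operand is addressable.

    Hypothetical layout: page p starts at row p * page_stride_rows, and only the
    first page_size rows of each stride hold this operand's tokens.
    """
    if page_stride_rows < page_size:
        raise ValueError("stride must be at least page_size")
    if num_pages == 0:
        return 0
    # The last valid row sits in the last page, at offset page_size - 1.
    return (num_pages - 1) * page_stride_rows + page_size


print(num_kv_rows(4, 16, 16))  # 64: contiguous pages, no gaps
print(num_kv_rows(4, 16, 32))  # 112: gaps between pages are included
```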
row_idx
row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32
Returns the row index of token start_tok_idx in batch entry batch_idx when the memory is viewed as a matrix.
Returns:
UInt32
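For a paged cache, mapping a token to a matrix row typically goes through a page table. The following Python sketch illustrates that idea under an assumed layout; page_table and page_stride_rows are hypothetical names, and real implementations of this trait may compute the row differently.

```python
def row_idx(page_table, page_size, batch_idx, start_tok_idx, page_stride_rows=None):
    """Map (batch_idx, start_tok_idx) to a row of the flat matrix view.

    Hypothetical paged layout: page_table[b][i] is the physical page that holds
    logical page i of batch entry b, and physical page p begins at row
    p * page_stride_rows (page_stride_rows defaults to page_size).
    """
    stride = page_size if page_stride_rows is None else page_stride_rows
    logical_page, offset = divmod(start_tok_idx, page_size)
    physical_page = page_table[batch_idx][logical_page]
    return physical_page * stride + offset


table = [[2, 0], [1, 3]]        # batch 0 uses pages 2 then 0; batch 1 uses 1 then 3
print(row_idx(table, 4, 0, 5))  # 1: token 5 is offset 1 in physical page 0
print(row_idx(table, 4, 1, 2))  # 6: token 2 is offset 2 in physical page 1
```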
create_tma_tile
create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode]()]
Creates a TMA tile for efficient GPU memory transfers. This is useful for k-major MMA operations where we don't need to mask any extra rows.
Returns:
TMATensorTile
create_scale_tma_tile
create_scale_tma_tile[BMN: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.scale_dtype, 2, Index[Int, Int](VariadicPack(1, BMN))]
Creates a TMA tile for efficient GPU memory transfers. This is useful for m-major MMA operations where we don't need to mask any extra rows.
Returns:
TMATensorTile
create_ragged_tma_tile
create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BN, BK]
Creates a TMA tile for efficient GPU memory transfers. This is useful for mn-major MMA operations where we need to mask extra rows to avoid adding NaN to the output through the MMA reduction.
Returns:
RaggedTMA3DTile[_Self.dtype, swizzle_mode, BN, BK]
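Why masking matters for mn-major operands can be shown with plain floating-point arithmetic. When a tile's rows lie along the MMA's reduction dimension (for example, the V operand of P·V in attention; this pairing is an illustration, not taken from this page), rows read past the end of the sequence are summed into every output element. Giving those rows a zero weight is not enough, because under IEEE 754 0 × NaN is NaN. In this Python sketch, a hand-written dot product stands in for one MMA output element.

```python
import math


def dot(weights, values):
    # Reduction over the shared dimension, like one output element of an MMA.
    return sum(w * v for w, v in zip(weights, values))


# Two valid rows plus one out-of-bounds row whose contents are garbage (NaN).
p = [0.25, 0.75, 0.0]          # weights; the padded row already has weight 0
v = [2.0, 4.0, float("nan")]   # values; the last row was read past the end

unmasked = dot(p, v)
masked = dot(p, [x if i < 2 else 0.0 for i, x in enumerate(v)])
print(math.isnan(unmasked))  # True: 0.0 * nan is nan, which poisons the sum
print(masked)                # 3.5
```

In the k-major case the extra rows instead produce extra output elements, which can simply be discarded, so no masking is needed there.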
create_rope_tma_tile
create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode]()]
Creates a BF16 TMA tile for the rope portion of the per-tensor rope-aware KV cache.
Returns:
TMATensorTile
create_gather4_tma_tile
create_gather4_tma_tile[row_width: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 2, IndexList(VariadicList(4, row_width), Tuple()), IndexList(VariadicList(1, row_width), Tuple())]
Creates a 2D TMA gather4 descriptor for this operand.
The descriptor views the data as a flat 2D matrix of [num_kv_rows, row_width] and is configured for gather4 operations that load 4 non-contiguous rows per TMA instruction.
Parameters:
- row_width (Int): Number of elements per row (innermost dimension).
- swizzle_mode (TensorMapSwizzle): TMA swizzle mode for the shared memory access pattern. Defaults to SWIZZLE_NONE.
Args:
- ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
Returns:
TMATensorTile: A TMATensorTile with tile_shape=(4, row_width) and desc_shape=(1, row_width).
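Because gather4 loads 4 rows whose indices need not be contiguous, a caller has to decide which rows go into each instruction. This Python sketch models only the indexing: the flat [num_kv_rows, row_width] view and the grouping of row indices into sets of 4, matching the (4, row_width) tile shape. The helper names are hypothetical, and nothing here calls the real TMA API.

```python
def flat_offset(row: int, col: int, row_width: int) -> int:
    # Element (row, col) of the [num_kv_rows, row_width] view, row-major.
    return row * row_width + col


def gather4_groups(row_indices):
    # Hypothetical helper: each gather4 instruction fetches 4 arbitrary rows,
    # so a list of rows is split into consecutive groups of at most 4.
    # Handling of a short final group is left to the caller.
    return [row_indices[i:i + 4] for i in range(0, len(row_indices), 4)]


rows = [7, 2, 19, 3, 11]
print(gather4_groups(rows))    # [[7, 2, 19, 3], [11]]
print(flat_offset(19, 5, 64))  # 1221
```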
scales_raw_ptr
scales_raw_ptr(self: _Self) -> UnsafePointer[Float32, MutAnyOrigin]
Returns the base pointer to the quantization scales tensor.
Returns a null pointer for operands without quantization support.
Returns:
UnsafePointer[Float32, MutAnyOrigin]
get_type_name
static get_type_name() -> String
Gets the name of the host type (the one implementing this trait). For example, Int returns "Int" and DeviceBuffer[DType.float32] returns "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.
Returns:
String: The host type's name.
Provided methods
copy
copy(self: _Self) -> _Self
Explicitly construct a copy of self, a convenience method for Self(copy=self) when the type is inconvenient to write out.
Returns:
_Self: A copy of this value.