Mojo trait

MHAOperand

A trait for types passed as operand arguments to the multi-head attention (MHA) kernel.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type

Indicates the type used on accelerator devices.

dtype

comptime dtype

page_size

comptime page_size

quantization_enabled

comptime quantization_enabled = False

quantization_granularity

comptime quantization_granularity

scale_dtype

comptime scale_dtype

Required methods

__init__

__init__(out self: _Self, *, copy: _Self)

Creates a new instance of the value by copying an existing one.

Args:

  • copy (_Self): The value to copy.

Returns:

_Self

__init__(out self: _Self, *, deinit take: _Self)

Creates a new instance of the value by moving the value out of another instance.

Args:

  • take (_Self): The value to move.

Returns:

_Self

block_paged_ptr

block_paged_ptr[tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]

scales_block_paged_ptr

scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.scale_dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[_Self.scale_dtype], ImmutAnyOrigin]

load_scale

load_scale[width: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]

Returns:

SIMD[_Self.scale_dtype, width]

cache_length

cache_length(self: _Self, batch_idx: Int) -> Int

Returns the length of the cache for a given batch index.

Returns:

Int

max_context_length

max_context_length(self: _Self) -> UInt32

Returns the maximum cache length over all batch indices.

Returns:

UInt32

num_kv_rows

num_kv_rows(self: _Self) -> Int

Returns the total number of virtual rows in the KV memory view.

For paged caches this accounts for the paging stride so that TMA descriptors can be sized to cover the entire address space.

Returns:

Int

row_idx

row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index when viewing the memory as a matrix.

Returns:

UInt32

get_tma_row

get_tma_row(self: _Self, encoded_index: Int32) -> Int32

Converts an encoded sparse index to a physical TMA row.

For paged caches, the encoded index is physical_block * page_size + offset, and this method returns physical_block * stride + offset. Non-paged operands return the encoded index unchanged.

Returns:

Int32
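
The paged mapping described for get_tma_row can be modeled in a few lines of plain Python. This is an illustrative sketch of the arithmetic only, not the Mojo implementation: the free functions below and their page_size and stride arguments are hypothetical stand-ins for state the operand would hold, and the sketch assumes offset < page_size so that the encoded index decodes unambiguously.

```python
def get_tma_row_paged(encoded_index: int, page_size: int, stride: int) -> int:
    """Model: physical_block * page_size + offset -> physical_block * stride + offset."""
    # Assumption: offset < page_size, so divmod recovers both fields exactly.
    physical_block, offset = divmod(encoded_index, page_size)
    return physical_block * stride + offset


def get_tma_row_non_paged(encoded_index: int) -> int:
    """Model: non-paged operands return the encoded index unchanged."""
    return encoded_index


# Hypothetical numbers: 128-token pages whose first rows sit 512 rows apart.
encoded = 3 * 128 + 5  # physical_block = 3, offset = 5
paged_row = get_tma_row_paged(encoded, page_size=128, stride=512)
print(paged_row)  # 3 * 512 + 5 = 1541
```

When stride equals page_size the paged mapping reduces to the identity, which matches the non-paged behavior.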

create_tma_tile

create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]

Creates a TMA tile for efficient GPU memory transfers. This is useful for k-major MMA operations where we don't need to mask any extra rows.

Returns:

TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_scale_tma_tile

create_scale_tma_tile[BMN: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.scale_dtype, 2, Index[Int, Int](1, BMN)]

Creates a TMA tile for efficient GPU memory transfers. This is useful for m-major MMA operations where we don't need to mask any extra rows.

Returns:

TMATensorTile[_Self.scale_dtype, 2, Index[Int, Int](1, BMN)]

create_ragged_tma_tile

create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BM=BN, BN=BK]

Creates a TMA tile for efficient GPU memory transfers. This is useful for mn-major MMA operations where we need to mask extra rows to avoid adding NaN to the output through the MMA reduction.

Returns:

RaggedTMA3DTile[_Self.dtype, swizzle_mode, BM=BN, BN=BK]
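
To see why extra rows must be masked before a reduction, consider the following plain-Python illustration. It is not kernel code: the vectors and the dot helper are hypothetical, and the only fact it relies on is the IEEE 754 rule that 0.0 * NaN is NaN, so a zero weight does not neutralize a garbage row.

```python
import math


def dot(weights, values):
    """Scalar stand-in for the reduction an MMA performs along one dimension."""
    return sum(w * v for w, v in zip(weights, values))


# Three valid rows plus one out-of-range row; its weight is already zero.
weights = [0.5, 0.25, 0.25, 0.0]

# Without masking, the out-of-range row may hold arbitrary bits, such as NaN.
values_unmasked = [1.0, 2.0, 3.0, float("nan")]
# With masking, the out-of-range row is replaced by zeros before the reduction.
values_masked = [1.0, 2.0, 3.0, 0.0]

out_unmasked = dot(weights, values_unmasked)
out_masked = dot(weights, values_masked)

print(math.isnan(out_unmasked))  # True: 0.0 * NaN poisons the whole sum
print(out_masked)  # 1.75
```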

create_rope_tma_tile

create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]

Creates a BF16 TMA tile for the rope portion of the per-tensor rope-aware KV cache.

Returns:

TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_gather4_tma_tile

create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, tma_dtype: DType = _Self.dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Creates a 2D TMA gather4 descriptor for this operand.

The descriptor views the data as a flat 2D matrix of [num_kv_rows, tile_width] and is configured for gather4 operations that load 4 non-contiguous rows per TMA instruction. The box width is derived from the swizzle mode; for SWIZZLE_NONE it equals tile_width.

When tma_dtype differs from Self.dtype, the underlying data pointer is bitcast to tma_dtype at descriptor creation time.

Parameters:

  • tile_width (Int): Number of elements per row to load (box width) in tma_dtype elements.
  • tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the global row is wider than the portion to load.
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern. Defaults to SWIZZLE_NONE.
  • tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
  • tma_dtype (DType): The data type used for the TMA descriptor. Defaults to Self.dtype. When different, the pointer is bitcast.
  • l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]: A TMATensorTile with box width derived from the swizzle mode.
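
The access pattern described above (a flat 2D view of num_kv_rows rows, from which four non-contiguous rows are loaded at once) can be sketched on the host in plain Python. The gather4 function here is a hypothetical model of the data movement, not a TMA or Mojo API; it ignores swizzling and treats the buffer as a flat list.

```python
def gather4(flat, row_indices, tile_width, tile_stride=None):
    """Model of one gather4 load: copy tile_width elements from each of 4 rows."""
    if tile_stride is None:
        tile_stride = tile_width  # mirrors the documented default
    if len(row_indices) != 4:
        raise ValueError("gather4 loads exactly 4 rows per instruction")
    return [
        flat[r * tile_stride : r * tile_stride + tile_width]
        for r in row_indices
    ]


# Hypothetical sizes: 8 rows, each 6 elements wide, of which 4 are loaded.
num_kv_rows, tile_stride, tile_width = 8, 6, 4
flat = list(range(num_kv_rows * tile_stride))

tile = gather4(flat, [7, 0, 3, 5], tile_width, tile_stride)
for row in tile:
    print(row)
```

Setting tile_stride larger than tile_width, as here, corresponds to the documented case where the global row is wider than the portion to load.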

create_rope_gather4_tma_tile

create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Creates a BF16 gather4 TMA descriptor for the rope portion of the KV cache.

For the per-tensor rope-aware layout each token row is padded_depth FP8 bytes (content) followed by BF16 rope elements. This method offsets the base pointer by padded_depth bytes, reinterprets as BF16, and creates a gather4 TMA descriptor.

Parameters:

  • tile_width (Int): Number of BF16 elements per row in global memory.
  • padded_depth (Int): Byte offset from row start to the rope data.
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern.
  • tile_height (Int): Number of rows in the tile. Must be a multiple of 4.
  • l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]: A BF16 TMATensorTile configured for gather4.
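
The row layout described above (padded_depth bytes of FP8 content followed by BF16 rope elements) and the "offset, then reinterpret" step can be modeled with Python's struct module. This is a host-side sketch under stated assumptions: the sizes are hypothetical, the buffer is little-endian, and BF16 is encoded as the upper 16 bits of a float32, which is exact for the values chosen here.

```python
import struct


def bf16_bits(x: float) -> int:
    """Upper 16 bits of the float32 encoding (exact for the values below)."""
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16


def bf16_to_float(bits: int) -> float:
    """Widen a BF16 bit pattern back to a Python float."""
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


padded_depth = 8  # hypothetical: bytes of FP8 content per token row
rope_elems = 4    # hypothetical: BF16 rope elements per token row
row_bytes = padded_depth + 2 * rope_elems

rope_rows = [[1.0, -2.0, 0.5, 4.0], [8.0, 0.25, -1.0, 2.0]]
buf = b"".join(
    bytes(padded_depth)  # placeholder FP8 content bytes
    + struct.pack(f"<{rope_elems}H", *(bf16_bits(x) for x in rope))
    for rope in rope_rows
)


def read_rope(buf: bytes, row: int) -> list:
    """Skip padded_depth bytes into the row, then reinterpret as BF16."""
    offset = row * row_bytes + padded_depth
    raw = struct.unpack_from(f"<{rope_elems}H", buf, offset)
    return [bf16_to_float(b) for b in raw]


decoded = [read_rope(buf, r) for r in range(len(rope_rows))]
print(decoded)
```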

scales_raw_ptr

scales_raw_ptr(self: _Self) -> UnsafePointer[Float32, MutAnyOrigin]

Returns the base pointer to the quantization scales tensor.

Returns a null pointer for operands without quantization support.

Returns:

UnsafePointer[Float32, MutAnyOrigin]

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", and DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device.

TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

Provided methods

block_paged_tile

block_paged_tile[layout_t: TensorLayout, //, tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, layout_val: layout_t, head_dim_idx: UInt32 = UInt32(0)) -> TileTensor[_Self.dtype, layout_t, ImmutAnyOrigin]

Wraps block_paged_ptr in a TileTensor with the caller's layout.

Returns:

TileTensor[_Self.dtype, layout_t, ImmutAnyOrigin]

populate

populate[BN: Int, pair_cta: Bool = False, is_leader: Bool = True](self: _Self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, _Self.page_size, pair_cta, is_leader]

Populates a full PagedRowIndices[BN, ...] for a BN-row tile.

Returns the precomputed physical row indices for the num_pages sub-tile pages covering the BN-row range starting at base_kv_row for batch_idx. Both K's TMA (which may cover only a subset in pair_cta mode) and V's TMA (which covers the full range) can then consume the result without any lazy LUT lookup.

The default implementation is a scalar loop of num_pages calls to row_idx. Overrides (for example, PagedKVCache) replace this loop with a single SIMD load from the underlying lookup table.

Returns:

PagedRowIndices[BN, _Self.page_size, pair_cta, is_leader]
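
The default behavior described above, a scalar loop of num_pages calls to row_idx, can be sketched in plain Python. Everything here is a hypothetical model rather than the Mojo implementation: num_pages is assumed to be ceil(BN / page_size), the result is a plain list instead of PagedRowIndices, and make_row_idx builds a toy paged row_idx from an invented lookup table.

```python
def populate(row_idx, batch_idx, base_kv_row, BN, page_size):
    """Scalar loop: one row_idx call per page covering the BN-row range."""
    num_pages = -(-BN // page_size)  # assumption: ceil(BN / page_size)
    return [
        row_idx(batch_idx, base_kv_row + page * page_size)
        for page in range(num_pages)
    ]


def make_row_idx(lookup_table, page_size):
    """Toy paged row_idx: logical page -> physical block via a lookup table."""
    def row_idx(batch_idx, tok_idx):
        block = lookup_table[batch_idx][tok_idx // page_size]
        return block * page_size + tok_idx % page_size
    return row_idx


# Hypothetical cache: batch 0 owns physical blocks 2, 0, 3 (in logical order).
page_size = 4
row_idx = make_row_idx({0: [2, 0, 3]}, page_size)

indices = populate(row_idx, batch_idx=0, base_kv_row=4, BN=8, page_size=page_size)
print(indices)  # first rows of logical pages 1 and 2: [0, 12]
```

An override that reads all page entries in one vector load would produce the same list without the per-page calls.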

copy

copy(self: _Self) -> _Self

Explicitly constructs a copy of self. This is a convenience method for Self(copy=self) when the type is inconvenient to write out.

Returns:

_Self: A copy of this value.