Mojo trait

KVCacheT

Trait for different KVCache types and implementations.

Represents a single (key or value) cache.

Implemented traits​

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

device_type​

comptime device_type

Indicates the type used on accelerator devices.

dtype​

comptime dtype

kv_params​

comptime kv_params

page_size_​

comptime page_size_

quantization_enabled​

comptime quantization_enabled = False

quantization_granularity​

comptime quantization_granularity = Int(1)

scale_dtype​

comptime scale_dtype = DType.invalid

Required methods​

__init__​

def __init__(out self: _Self, *, copy: _Self)

Create a new instance of the value by copying an existing one.

Args:

  • ​copy (_Self): The value to copy.

Returns:

_Self

def __init__(out self: _Self, *, deinit move: _Self)

Create a new instance of the value by moving the value of another.

Args:

  • ​move (_Self): The value to move.

Returns:

_Self

cache_lengths_nd​

def cache_lengths_nd(self: _Self) -> TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin]

Returns the cache lengths as a TileTensor.

Returns:

TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin]

cache_length​

def cache_length(self: _Self, batch_idx: Int) -> Int

Returns the length of the cache for a given batch index.

Returns:

Int

load​

def load[width: Int, output_dtype: DType = _Self.dtype](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]

Loads an element from the given index.

Returns:

SIMD[output_dtype, width]

store​

def store(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[_Self.dtype])

Stores an element at the given index.

store_scale​

def store_scale(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[_Self.scale_dtype])

Stores the quantization scales at the given index.

load_scale​

def load_scale[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]

Loads the quantization scales from the given index.

Returns:

SIMD[_Self.scale_dtype, width]

load_quantized​

def load_quantized[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.dtype, width]

Loads a quantized element from the given index.

Returns:

SIMD[_Self.dtype, width]

empty_cache​

def empty_cache(self: _Self) -> Bool

Returns true if the cache length for every request is 0, false otherwise.

Returns:

Bool

max_prompt_length​

def max_prompt_length(self: _Self) -> UInt32

Returns the maximum sequence length across all batches of the current request.

Returns:

UInt32

max_context_length​

def max_context_length(self: _Self) -> UInt32

Returns the maximum cache length used across all batches of the current request.

Returns:

UInt32

block_paged_ptr​

def block_paged_ptr[tile_size: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[_Self.dtype], MutAnyOrigin]

Returns a pointer to the KVCache block at the given index.

Paged KVCache implementations must have a block_size that is a multiple of, and greater than, the layout's first dimension.

Returns:

UnsafePointer[Scalar[_Self.dtype], MutAnyOrigin]

scales_block_paged_ptr​

def scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]

Returns a pointer to the scales block at the requested indices.

Returns:

UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]

scales_raw_ptr​

def scales_raw_ptr(self: _Self) -> UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]

Returns the base pointer to the scales tensor.

For PagedKVCache with quantization enabled, this returns the raw base pointer of the scales TileTensor. For caches without quantization, returns a null pointer.

Returns:

UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]

max_tile_size​

static def max_tile_size() -> Int

Returns the maximum tile size for the KVCache.

Returns:

Int

num_kv_rows​

def num_kv_rows(self: _Self) -> Int

Returns the total number of virtual rows in this KV cache view.

For paged caches this accounts for the paging stride: (total_blocks - 1) * stride + page_size.

Returns:

Int
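
The paged formula above can be checked with a small Python model. This is only a sketch of the arithmetic, not the Mojo implementation: the function name and arguments are hypothetical, and stride is assumed to be the row distance between the starts of consecutive pages when the memory is viewed as a matrix.

```python
def num_kv_rows_paged(total_blocks: int, stride: int, page_size: int) -> int:
    """Model of the documented formula: (total_blocks - 1) * stride + page_size.

    Every block except the last spans a full stride of rows; the last block
    only contributes its page_size valid rows.
    Assumption: `stride` is measured in rows, not bytes.
    """
    return (total_blocks - 1) * stride + page_size


# Four pages of 128 rows whose starts are 256 rows apart.
rows_strided = num_kv_rows_paged(total_blocks=4, stride=256, page_size=128)

# When stride equals page_size the view is dense: total_blocks * page_size rows.
rows_dense = num_kv_rows_paged(total_blocks=4, stride=128, page_size=128)
```

With these hypothetical values the strided view spans 896 rows and the dense view spans 512 rows.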

row_idx​

def row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index when viewing the memory as a matrix.

Returns:

UInt32

get_tma_row​

def get_tma_row(self: _Self, encoded_index: Int32) -> Int32

Convert an encoded sparse index to a physical TMA row.

For paged caches the encoded index is physical_block * page_size + offset and this method returns physical_block * stride + offset. Non-paged caches return the encoded index unchanged.

Returns:

Int32
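
The index conversion described above can be sketched in plain Python. This models only the documented arithmetic; the function name and the page_size and stride arguments are illustrative assumptions, since the trait method takes just the encoded index and reads the rest from the cache.

```python
def tma_row_from_encoded(encoded_index: int, page_size: int, stride: int) -> int:
    """Decode `physical_block * page_size + offset`, then re-encode with stride.

    Assumptions: 0 <= offset < page_size, and `stride` is in rows.
    """
    physical_block, offset = divmod(encoded_index, page_size)
    return physical_block * stride + offset


# Block 3, offset 5, with 128-row pages laid out 256 rows apart.
encoded = 3 * 128 + 5
tma_row = tma_row_from_encoded(encoded, page_size=128, stride=256)

# If stride == page_size the mapping is the identity, matching the
# documented behavior of non-paged caches.
unchanged = tma_row_from_encoded(encoded, page_size=128, stride=128)
```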

create_tma_tile​

def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, _Self.kv_params.head_size]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, Int(3), _padded_shape[Int(3), _Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), _Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

Creates a TMA tile for this KV cache. This is useful for k-major MMA operations where we don't need to mask any extra rows.

Returns:

TMATensorTile[_Self.dtype, Int(3), _padded_shape[Int(3), _Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), _Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_ragged_tma_tile​

def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, _Self.kv_params.head_size]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BM=BN, BN=BK]

Creates a TMA tile for this KV cache. This is useful for mn-major MMA operations where we need to mask extra rows to avoid adding NaN to the output through the MMA reduction.

Returns:

RaggedTMA3DTile[_Self.dtype, swizzle_mode, BM=BN, BN=BK]

create_rope_tma_tile​

def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

Creates a BF16 TMA tile for the rope portion of the KV cache.

For the per-tensor rope-aware layout, each token row in the KV cache is stored as padded_depth FP8 bytes (content) followed by BK BF16 elements (rope). This method returns a TMA descriptor that points at the rope data starting at byte offset padded_depth within each row, reinterpreted as BF16.

Returns:

TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]
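
The row layout described above implies some simple byte arithmetic, sketched here in Python. The helper name and the example sizes are hypothetical, and the sketch assumes rows are packed back to back with no extra padding, which the source does not state.

```python
FP8_BYTES = 1   # one byte per FP8 content element
BF16_BYTES = 2  # two bytes per BF16 rope element


def rope_row_layout(padded_depth: int, bk: int) -> dict:
    """Byte layout of one token row: padded_depth FP8 bytes, then bk BF16 values."""
    if padded_depth % BF16_BYTES != 0:
        # The rope data is reinterpreted as BF16, so its start must be 2-byte aligned.
        raise ValueError("padded_depth must be a multiple of 2 bytes")
    row_bytes = padded_depth * FP8_BYTES + bk * BF16_BYTES
    return {
        "row_bytes": row_bytes,                        # full row size in bytes
        "rope_byte_offset": padded_depth,              # where the rope data starts
        "rope_offset_bf16": padded_depth // BF16_BYTES,
        "row_stride_bf16": row_bytes // BF16_BYTES,    # row stride seen as BF16
    }


# Hypothetical sizes: 512 content bytes followed by 64 rope elements.
layout = rope_row_layout(padded_depth=512, bk=64)
```

For these example sizes a row is 640 bytes, and the rope data begins 512 bytes (256 BF16 elements) into the row.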

create_gather4_tma_tile​

def create_gather4_tma_tile[*, tile_height: Int = Int(4), tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tma_dtype: DType = _Self.dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Creates a 2D TMA gather4 descriptor for this KV cache.

The descriptor views the KV cache as a flat 2D matrix of [num_kv_rows, tile_width] and is configured for gather4 operations that load 4 non-contiguous rows per TMA instruction. The box width is derived from the swizzle mode; for SWIZZLE_NONE it equals tile_width.

The tile_height parameter records the full tile height (e.g. 64 rows) in the returned TMATensorTile.tile_shape. The hardware descriptor shape stays (1, box_width) as required by TMA gather4.

When tma_dtype differs from Self.dtype, the underlying data pointer is bitcast to tma_dtype at descriptor creation time. This allows, for example, creating an INT64/SWIZZLE_NONE descriptor over FP8 data for linear SMEM layout.

Parameters:

  • ​tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
  • ​tile_width (Int): Number of elements per row to load (box width) in tma_dtype elements.
  • ​tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the global row is wider than the portion to load.
  • ​swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern. Defaults to SWIZZLE_NONE.
  • ​tma_dtype (DType): The data type used for the TMA descriptor. Defaults to Self.dtype. When different, the pointer is bitcast.
  • ​l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ​ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]: A TMATensorTile with box width derived from the swizzle mode.
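
Two consequences of the description above can be sketched in Python: a tile of tile_height rows is covered by tile_height / 4 gather4 loads, and bitcasting to a wider tma_dtype shrinks the element count per row. The helpers below are hypothetical models of that bookkeeping, not part of the Mojo API.

```python
def gather4_groups(row_indices: list) -> list:
    """Split a tile's row indices into groups of 4, one group per gather4 load."""
    if len(row_indices) % 4 != 0:
        raise ValueError("tile_height must be a multiple of 4")
    return [tuple(row_indices[i:i + 4]) for i in range(0, len(row_indices), 4)]


def width_in_tma_elements(row_bytes: int, tma_elem_bytes: int) -> int:
    """Row width expressed in tma_dtype elements after the pointer is bitcast."""
    if row_bytes % tma_elem_bytes != 0:
        raise ValueError("row width must be a whole number of tma_dtype elements")
    return row_bytes // tma_elem_bytes


# Eight non-contiguous rows need two gather4 loads.
groups = gather4_groups([7, 2, 9, 4, 11, 0, 3, 8])

# 512 FP8 bytes viewed as INT64 (8 bytes each) give a tile_width of 64.
int64_width = width_in_tma_elements(row_bytes=512, tma_elem_bytes=8)
```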

create_rope_gather4_tma_tile​

def create_rope_gather4_tma_tile[*, tile_height: Int = Int(4), tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Creates a BF16 gather4 TMA descriptor for the rope portion of the KV cache.

For the per-tensor rope-aware layout each token row is stored as padded_depth FP8 bytes (content) followed by BF16 rope elements. This method offsets the base pointer by padded_depth bytes, reinterprets as BF16, and creates a gather4 TMA descriptor with tile_width BF16 elements per row.

Parameters:

  • ​tile_height (Int): Number of rows in the tile. Must be a multiple of 4.
  • ​tile_width (Int): Number of BF16 elements per row in global memory.
  • ​padded_depth (Int): Byte offset from row start to the rope data.
  • ​swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern.
  • ​l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ​ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]: A BF16 TMATensorTile configured for gather4.

get_type_name​

static def get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

Provided methods​

populate​

def populate[BN: Int, base_alignment: Int, pair_cta: Bool = False, is_leader: Bool = True](self: _Self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, _Self.page_size_, pair_cta, is_leader]

Populate a full PagedRowIndices[BN, ...] for a BN-row tile.

base_alignment is a comptime promise that base_kv_row % base_alignment == 0 at runtime; it is typically mask.start_column_alignment[...](). The PagedKVCache override uses it to pick the largest legal SIMD chunk for its LUT vector load and to skip the intra-page divmod when base_alignment % page_size == 0.

Default: scalar loop over num_pages calls to row_idx. The PagedKVCache override replaces this with a single aligned SIMD load against the lookup table.

Returns:

PagedRowIndices[BN, _Self.page_size_, pair_cta, is_leader]
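
The default scalar path and the alignment shortcut can be modeled roughly in Python. Everything here is an assumption made for illustration: the lookup table shape, the toy row_idx, and the choice of one row_idx call per page with num_pages = ceil(BN / page_size). The real PagedRowIndices type is not modeled.

```python
def make_row_idx(lut, page_size, stride):
    """Toy row_idx: look up the physical block for a token, then add the offset."""
    def row_idx(batch_idx, tok_idx):
        page, offset = divmod(tok_idx, page_size)  # the intra-page divmod
        return lut[batch_idx][page] * stride + offset
    return row_idx


def populate_default(row_idx, batch_idx, base_kv_row, bn, page_size):
    """Scalar loop: one row_idx call per page covered by the BN-row tile."""
    num_pages = -(-bn // page_size)  # ceiling division (assumed page count)
    return [row_idx(batch_idx, base_kv_row + p * page_size) for p in range(num_pages)]


def can_skip_intra_page_divmod(base_alignment, page_size):
    """If base_kv_row % base_alignment == 0 and base_alignment % page_size == 0,
    then base_kv_row % page_size == 0, so the intra-page offset is always zero."""
    return base_alignment % page_size == 0


# One sequence whose logical pages 0..2 live in physical blocks 5, 2, 7.
row_idx = make_row_idx(lut=[[5, 2, 7]], page_size=128, stride=256)
tile_rows = populate_default(row_idx, batch_idx=0, base_kv_row=128, bn=256, page_size=128)
```

In this toy setup the tile starting at row 128 covers logical pages 1 and 2, so the loop yields the starting rows of physical blocks 2 and 7.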

copy​

def copy(self: _Self) -> _Self

Explicitly constructs a copy of self. This is a convenience method for Self(copy=self) when the type is inconvenient to write out.

Overriding this method is not allowed.

Returns:

_Self: A copy of this value.