
Mojo struct

KVCacheMHAOperand

struct KVCacheMHAOperand[cache_t: KVCacheT]

An MHAOperand implementation for mo.opaque KVCacheT arguments passed to MHA kernels.

We can eventually remove this trait and add it as a sub-trait of the KVCacheT type instead, but some cyclic dependencies must be resolved first.

Fields

  • cache (cache_t): The underlying KV cache that this operand wraps.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, MHAOperand, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = KVCacheMHAOperand[cache_t]

dtype

comptime dtype = cache_t.dtype

page_size

comptime page_size = cache_t.page_size_

quantization_enabled

comptime quantization_enabled = cache_t.quantization_enabled

quantization_granularity

comptime quantization_granularity = cache_t.quantization_granularity

scale_dtype

comptime scale_dtype = cache_t.scale_dtype
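The comptime members above forward the wrapped cache's compile-time parameters, and several methods below also delegate to the cache. As a rough illustration of that adapter pattern, here is a plain-Python sketch. `ToyKVCache` and `ToyMHAOperand` are hypothetical stand-ins, not MAX APIs, and the runtime properties only approximate what Mojo resolves at compile time.

```python
class ToyKVCache:
    """Hypothetical stand-in for a KVCacheT implementation."""

    def __init__(self, dtype, page_size, quantization_enabled, lengths):
        self.dtype = dtype
        self.page_size_ = page_size  # trailing underscore mirrors cache_t.page_size_
        self.quantization_enabled = quantization_enabled
        self._lengths = list(lengths)  # cached tokens per batch entry

    def cache_length(self, batch_idx):
        return self._lengths[batch_idx]


class ToyMHAOperand:
    """Adapter that stores one cache and re-exports its properties."""

    def __init__(self, cache):
        self.cache = cache  # the operand's only field

    # Analogues of `comptime dtype = cache_t.dtype` and friends.
    @property
    def dtype(self):
        return self.cache.dtype

    @property
    def page_size(self):
        return self.cache.page_size_

    @property
    def quantization_enabled(self):
        return self.cache.quantization_enabled

    def cache_length(self, batch_idx):
        # Assumed here to forward directly to the wrapped cache.
        return self.cache.cache_length(batch_idx)


op = ToyMHAOperand(ToyKVCache("bfloat16", 128, False, [5, 17]))
```

Because the operand holds nothing but the cache, any KVCacheT can be handed to an MHA kernel without copying or reshaping its data.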

Methods

__init__

def __init__(cache: cache_t) -> Self

get_type_name

static def get_type_name() -> String

Returns:

String

block_paged_ptr

def block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]

scales_block_paged_ptr

def scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[Self.scale_dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[Self.scale_dtype], ImmutAnyOrigin]

load_scale

def load_scale[width: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[Self.scale_dtype, width]

Returns:

SIMD[Self.scale_dtype, width]

cache_length

def cache_length(self, batch_idx: Int) -> Int

Returns:

Int

max_context_length

def max_context_length(self) -> UInt32

Returns:

UInt32

num_kv_rows

def num_kv_rows(self) -> Int

Returns the total number of virtual rows in the KV memory view.

Returns:

Int

row_idx

def row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index for the given batch and start token when the KV memory is viewed as a matrix.

Returns:

UInt32
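To make the matrix view concrete, the sketch below assumes a vLLM-style paged layout in which each sequence owns a list of physical page ids and each page stores `page_size` consecutive token rows. Under that assumption, `row_idx` and `num_kv_rows` reduce to simple arithmetic. The function names mirror the methods above, but the layout and the `lookup_table` argument are assumptions for illustration, not the documented behavior of every KVCacheT.

```python
def row_idx(lookup_table, page_size, batch_idx, start_tok_idx):
    """Map (batch, token) to a row of the flattened KV matrix (assumed paged layout)."""
    # Split the logical token position into (logical page, offset in page).
    logical_page, offset = divmod(start_tok_idx, page_size)
    # Translate the logical page to the physical page that stores it.
    physical_page = lookup_table[batch_idx][logical_page]
    return physical_page * page_size + offset


def num_kv_rows(num_physical_pages, page_size):
    # Every physical page contributes page_size rows to the virtual matrix.
    return num_physical_pages * page_size


# Two sequences sharing a pool of four pages with four tokens each.
table = [[2, 0], [1, 3]]
print(row_idx(table, 4, 0, 5))  # 1: token 5 of sequence 0 is offset 1 in physical page 0
print(num_kv_rows(4, 4))        # 16
```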

populate

def populate[BN: Int, base_alignment: Int, pair_cta: Bool = False, is_leader: Bool = True](self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, cache_t.page_size_, pair_cta, is_leader]

Delegates to the underlying cache's populate.

PagedKVCache.populate overrides this with a SIMD lookup-table read; other cache types fall back to the scalar default implementation in KVCacheT.populate. base_alignment is the comptime alignment of base_kv_row (typically mask.start_column_alignment[...]()).

Returns:

PagedRowIndices[BN, cache_t.page_size_, pair_cta, is_leader]
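The difference between a lookup-table read and the scalar default can be modeled in a few lines of Python under the same assumed paged layout: the scalar path consults the page table once per row, while a page-aligned base row lets the page-table entries covering a BN-row tile be read as one contiguous slice. This only sketches why knowing the alignment of `base_kv_row` at compile time can help; both helper names are hypothetical, and the real PagedRowIndices layout (including `pair_cta` and `is_leader`) is not modeled.

```python
def rows_scalar(lookup_table, page_size, batch_idx, base_kv_row, bn):
    # Scalar path: one page-table lookup per row of the tile.
    out = []
    for r in range(base_kv_row, base_kv_row + bn):
        page, offset = divmod(r, page_size)
        out.append(lookup_table[batch_idx][page] * page_size + offset)
    return out


def rows_from_page_slice(lookup_table, page_size, batch_idx, base_kv_row, bn):
    # Slice path: with a page-aligned base row, read the contiguous run of
    # page-table entries covering the tile at once, then expand each entry
    # into page_size consecutive physical rows.
    assert base_kv_row % page_size == 0, "requires a page-aligned base row"
    first = base_kv_row // page_size
    n_pages = -(-bn // page_size)  # ceiling division
    pages = lookup_table[batch_idx][first:first + n_pages]
    rows = [p * page_size + o for p in pages for o in range(page_size)]
    return rows[:bn]


table = [[3, 0, 2, 1]]  # one sequence spread over four 4-row pages
assert rows_scalar(table, 4, 0, 4, 8) == rows_from_page_slice(table, 4, 0, 4, 8)
```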

get_tma_row

def get_tma_row(self, encoded_index: Int32) -> Int32

Convert an encoded sparse index to a physical TMA row.

Returns:

Int32

create_tma_tile

def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[cache_t.dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Creates a TMA tile for efficient GPU memory transfers.

Returns:

TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_scale_tma_tile

def create_scale_tma_tile[BMN: Int](self, ctx: DeviceContext, out tma: TMATensorTile[Self.scale_dtype, Int(2), Index[Int, Int](Int(1), BMN)])

Creates a TMA tile over the quantization scales (element type Self.scale_dtype) for efficient GPU memory transfers. This is useful for m-major MMA operations where no extra rows need to be masked.

Returns:

TMATensorTile[Self.scale_dtype, Int(2), Index[Int, Int](Int(1), BMN)]

create_ragged_tma_tile

def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[cache_t.dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK])

Returns:

RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK]

create_rope_tma_tile

def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Delegates to the underlying KVCache to create a BF16 rope TMA tile.

Returns:

TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_gather4_tma_tile

def create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), tma_dtype: DType = KVCacheMHAOperand[cache_t].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])

Creates a 2D TMA gather4 descriptor for this KV cache operand.

Returns:

TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]
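The TMA gather4 mode on recent NVIDIA GPUs loads four rows whose row coordinates are supplied individually, all sliced to the same contiguous column window (note the default tile_height of 4 above). For a paged KV cache this allows non-adjacent physical rows to be fetched through a single descriptor. The plain-Python model below shows only that data movement: `gather4` and its arguments are hypothetical, and it ignores everything create_gather4_tma_tile actually configures (swizzle_mode, the _gather4_box_width computation, l2_promotion).

```python
def gather4(matrix, row_indices, col_start, box_width):
    """Model of a 2D gather4 load: four independently indexed rows,
    each sliced to the same contiguous column window."""
    if len(row_indices) != 4:
        raise ValueError("gather4 fetches exactly four rows")
    return [list(matrix[r][col_start:col_start + box_width]) for r in row_indices]


# A 6 x 8 matrix whose entry (r, c) is 10 * r + c.
kv = [[10 * r + c for c in range(8)] for r in range(6)]
tile = gather4(kv, [5, 0, 3, 3], col_start=2, box_width=3)
```

Rows may repeat or appear in any order, which is what lets one descriptor serve arbitrary page-table indirection.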

create_rope_gather4_tma_tile

def create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])

Delegates to the underlying KVCache to create a BF16 rope gather4 TMA tile.

Returns:

TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

scales_raw_ptr

def scales_raw_ptr(self) -> UnsafePointer[Float32, MutAnyOrigin]

Returns the base pointer to the quantization scales tensor.

Returns:

UnsafePointer[Float32, MutAnyOrigin]