Mojo struct

KVCacheMHAOperand

struct KVCacheMHAOperand[cache_t: KVCacheT]

An MHAOperand implementation that wraps mo.opaque KVCacheT arguments passed to MHA kernels.

We can eventually remove this wrapper and add its functionality to KVCacheT as a sub-trait, but some cyclic dependencies need to be resolved first.

Fields

  • cache (cache_t): The wrapped KV cache.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, MHAOperand, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = KVCacheMHAOperand[cache_t]

dtype

comptime dtype = cache_t.dtype

page_size

comptime page_size = cache_t.page_size_

quantization_enabled

comptime quantization_enabled = cache_t.quantization_enabled

quantization_granularity

comptime quantization_granularity = cache_t.quantization_granularity

scale_dtype

comptime scale_dtype = cache_t.scale_dtype
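Each comptime member above re-exports a parameter of the wrapped cache type. Note that page_size is read from cache_t.page_size_. The Python sketch below is only a rough illustration that models this forwarding with runtime attributes. ToyCacheParams and KVCacheMHAOperandModel are hypothetical names and the field values are invented. In Mojo these members are resolved at compile time and are not stored per instance.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyCacheParams:
    """Hypothetical stand-in for a KVCacheT type's compile-time parameters."""
    dtype: str = "bfloat16"
    page_size_: int = 128
    quantization_enabled: bool = False
    quantization_granularity: int = 64
    scale_dtype: str = "float32"


class KVCacheMHAOperandModel:
    """Wraps a cache and re-exports its parameters under the MHAOperand names."""

    def __init__(self, cache):
        self.cache = cache  # the wrapped cache (the struct's only field)
        self.dtype = cache.dtype
        self.page_size = cache.page_size_  # source member has a trailing underscore
        self.quantization_enabled = cache.quantization_enabled
        self.quantization_granularity = cache.quantization_granularity
        self.scale_dtype = cache.scale_dtype


op = KVCacheMHAOperandModel(ToyCacheParams())
print(op.page_size, op.dtype)  # 128 bfloat16
```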

Methods

__init__

__init__(cache: cache_t) -> Self

get_type_name

static get_type_name() -> String

Returns:

String

block_paged_ptr

block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[KVCacheMHAOperand[cache_t].dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[KVCacheMHAOperand[cache_t].dtype], ImmutAnyOrigin]
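block_paged_ptr returns a pointer to the first element of a tile of tile_size tokens. The Python sketch below shows the kind of address arithmetic this involves by computing an element offset under an assumed layout. Pages are stored as [num_pages, page_size, num_heads, head_dim] in row-major order, and a per-sequence block table maps logical pages to physical pages. The layout, the block_table argument, and the rule that a tile must not straddle a page boundary are all assumptions of this sketch. None of them is documented behavior of any specific cache_t.

```python
def block_paged_offset(block_table, page_size, num_heads, head_dim,
                       batch_idx, start_tok_idx, head_idx, head_dim_idx=0,
                       tile_size=1):
    """Element offset of (batch, token, head, dim) in a hypothetical paged
    layout [num_pages, page_size, num_heads, head_dim], row-major."""
    page_in_seq, tok_in_page = divmod(start_tok_idx, page_size)
    # Assumption of this sketch: a tile never straddles two pages.
    assert tok_in_page + tile_size <= page_size
    phys_page = block_table[batch_idx][page_in_seq]
    row = phys_page * page_size + tok_in_page  # row in the stacked-page matrix
    return (row * num_heads + head_idx) * head_dim + head_dim_idx


# Batch 0 uses physical pages 5 then 2; batch 1 uses page 7.
table = [[5, 2], [7]]
off = block_paged_offset(table, page_size=16, num_heads=8, head_dim=128,
                         batch_idx=0, start_tok_idx=20, head_idx=3, tile_size=4)
print(off)  # 37248
```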

scales_block_paged_ptr

scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[KVCacheMHAOperand[cache_t].scale_dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[KVCacheMHAOperand[cache_t].scale_dtype], ImmutAnyOrigin]

load_scale

load_scale[width: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[KVCacheMHAOperand[cache_t].scale_dtype, width]

Returns:

SIMD[KVCacheMHAOperand[cache_t].scale_dtype, width]
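When quantization_enabled is true, stored K/V values must be rescaled using scales such as those returned by load_scale. The Python sketch below rests on two assumptions that this reference does not state. First, there is one scale per quantization_granularity contiguous elements along the head dimension. Second, dequantization is a plain multiplication. load_scale_model and dequantize_head are hypothetical helpers.

```python
def load_scale_model(scales, width, head_dim_idx, granularity):
    """Return `width` consecutive scales, starting at the group that
    contains head_dim_idx (assumed: one scale per `granularity` elements)."""
    first = head_dim_idx // granularity
    return scales[first:first + width]


def dequantize_head(q_values, scales, granularity):
    """Multiply each quantized element by the scale of its group."""
    return [q * scales[i // granularity] for i, q in enumerate(q_values)]


q = [1, 2, 3, 4, -1, -2, -3, -4]
s = [0.5, 0.25]  # one scale per group of 4 elements
print(dequantize_head(q, s, granularity=4))
# [0.5, 1.0, 1.5, 2.0, -0.25, -0.5, -0.75, -1.0]
```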

cache_length

cache_length(self, batch_idx: Int) -> Int

Returns:

Int

max_context_length

max_context_length(self) -> UInt32

Returns:

UInt32

num_kv_rows

num_kv_rows(self) -> Int

Returns the total number of virtual rows in the KV memory view.

Returns:

Int

row_idx

row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index of token start_tok_idx in sequence batch_idx when the KV memory is viewed as a 2D matrix.

Returns:

UInt32
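num_kv_rows and row_idx can be pictured as a paged cache whose pages are stacked into one matrix, with one row per token slot. The Python sketch below models that view. PagedRowView, the block table, and the page counts are hypothetical, and non-paged cache types may define the virtual matrix differently.

```python
class PagedRowView:
    """Hypothetical model: KV memory as a matrix of num_pages * page_size rows."""

    def __init__(self, block_table, num_pages, page_size):
        self.block_table = block_table  # per-sequence list of physical page ids
        self.num_pages = num_pages
        self.page_size = page_size

    def num_kv_rows(self):
        # Every token slot of every physical page is one virtual row.
        return self.num_pages * self.page_size

    def row_idx(self, batch_idx, start_tok_idx):
        # Logical page -> physical page, then offset within that page.
        page, offset = divmod(start_tok_idx, self.page_size)
        return self.block_table[batch_idx][page] * self.page_size + offset


view = PagedRowView([[5, 2], [7]], num_pages=8, page_size=16)
print(view.num_kv_rows())   # 128
print(view.row_idx(0, 20))  # 36
print(view.row_idx(1, 3))   # 115
```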

populate

populate[BN: Int, pair_cta: Bool = False, is_leader: Bool = True](self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, KVCacheMHAOperand[cache_t].page_size, pair_cta, is_leader]

Delegates to the underlying cache's populate method.

PagedKVCache.populate overrides the default with a SIMD lookup-table read; other cache types fall back to the scalar default implementation in KVCacheT.populate.

Returns:

PagedRowIndices[BN, KVCacheMHAOperand[cache_t].page_size, pair_cta, is_leader]
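The delegation described above amounts to this: the operand forwards to the cache, the trait supplies a scalar default, and a paged cache overrides it. The Python sketch below is a loose model of that structure. PagedRowIndices is represented as a plain list of row indices, base_kv_row is treated as a starting token index, and the override only imitates a lookup-table read by fetching each page id once per page. pair_cta and is_leader are not modeled, and every class name is hypothetical.

```python
class KVCacheBase:
    """Hypothetical stand-in for the KVCacheT trait and its scalar default."""

    def row_idx(self, batch_idx, tok):
        raise NotImplementedError

    def populate(self, bn, batch_idx, base_tok):
        # Scalar default: one row_idx call per token in the BN-row tile.
        return [self.row_idx(batch_idx, base_tok + i) for i in range(bn)]


class PagedCacheModel(KVCacheBase):
    def __init__(self, block_table, page_size):
        self.block_table = block_table
        self.page_size = page_size

    def row_idx(self, batch_idx, tok):
        page, off = divmod(tok, self.page_size)
        return self.block_table[batch_idx][page] * self.page_size + off

    def populate(self, bn, batch_idx, base_tok):
        # Override: read each page id from the table once, then expand it
        # into a run of consecutive rows.
        rows, tok, end = [], base_tok, base_tok + bn
        while tok < end:
            page, off = divmod(tok, self.page_size)
            n = min(self.page_size - off, end - tok)
            first = self.block_table[batch_idx][page] * self.page_size + off
            rows.extend(range(first, first + n))
            tok += n
        return rows


class OperandModel:
    def __init__(self, cache):
        self.cache = cache

    def populate(self, bn, batch_idx, base_tok):
        return self.cache.populate(bn, batch_idx, base_tok)  # pure delegation


cache = PagedCacheModel([[5, 2]], page_size=4)
op = OperandModel(cache)
print(op.populate(6, batch_idx=0, base_tok=2))  # [22, 23, 8, 9, 10, 11]
```

Both paths must produce the same indices. The override only changes how the page table is read.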

get_tma_row

get_tma_row(self, encoded_index: Int32) -> Int32

Convert an encoded sparse index to a physical TMA row.

Returns:

Int32

create_tma_tile

create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[KVCacheMHAOperand[cache_t].dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: TMATensorTile[KVCacheMHAOperand[cache_t].dtype, 3, _padded_shape[3, KVCacheMHAOperand[cache_t].dtype, IndexList(BN, 1, BK), swizzle_mode](), _ragged_shape[3, KVCacheMHAOperand[cache_t].dtype, IndexList(BN, 1, BK), swizzle_mode]()])

Creates a TMA tile for efficient GPU memory transfers.

Returns:

TMATensorTile[KVCacheMHAOperand[cache_t].dtype, 3, _padded_shape[3, KVCacheMHAOperand[cache_t].dtype, IndexList(BN, 1, BK), swizzle_mode](), _ragged_shape[3, KVCacheMHAOperand[cache_t].dtype, IndexList(BN, 1, BK), swizzle_mode]()]

create_scale_tma_tile

create_scale_tma_tile[BMN: Int](self, ctx: DeviceContext, out tma: TMATensorTile[KVCacheMHAOperand[cache_t].scale_dtype, 2, Index[Int, Int](1, BMN)])

Creates a TMA tile over the quantization scales (scale_dtype) for efficient GPU memory transfers. This is useful for m-major MMA operations, where no extra rows need to be masked.

Returns:

TMATensorTile[KVCacheMHAOperand[cache_t].scale_dtype, 2, Index[Int, Int](1, BMN)]

create_ragged_tma_tile

create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[KVCacheMHAOperand[cache_t].dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[KVCacheMHAOperand[cache_t].dtype, swizzle_mode, BM=BN, BN=BK])

Returns:

RaggedTMA3DTile[KVCacheMHAOperand[cache_t].dtype, swizzle_mode, BM=BN, BN=BK]

create_rope_tma_tile

create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK), swizzle_mode]()])

Delegates to the underlying KVCache to create a BF16 rope TMA tile.

Returns:

TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK), swizzle_mode]()]

create_gather4_tma_tile

create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, tma_dtype: DType = KVCacheMHAOperand[cache_t].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode]()), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode]())])

Creates a 2D TMA gather4 descriptor for this KV cache operand.

Returns:

TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode]()), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode]())]
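A gather4 load fetches four rows chosen by independent row coordinates, sharing one column window, into a single tile; this matches the default tile_height of 4. In a paged cache this plausibly lets one descriptor fetch rows that live on different pages. The Python sketch below models only the data movement. It ignores swizzling, the box width computed by _gather4_box_width, and L2 promotion, and gather4 is a hypothetical helper, not the descriptor API.

```python
def gather4(matrix, row_indices, col_start, tile_width):
    """Model of a gather4 load: copy tile_width columns starting at col_start
    from each of four independently chosen rows."""
    assert len(row_indices) == 4, "gather4 reads exactly four rows"
    return [matrix[r][col_start:col_start + tile_width] for r in row_indices]


# A 16 x 8 toy "KV matrix" where element (r, c) = 10 * r + c.
kv = [[10 * r + c for c in range(8)] for r in range(16)]
tile = gather4(kv, [9, 2, 14, 2], col_start=4, tile_width=4)
print(tile)
# [[94, 95, 96, 97], [24, 25, 26, 27], [144, 145, 146, 147], [24, 25, 26, 27]]
```

The row indices need not be sorted or distinct, which is what distinguishes a gather from a regular rectangular tile load.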

create_rope_gather4_tma_tile

create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode]()), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode]())])

Delegates to the underlying KVCache to create a BF16 rope gather4 TMA tile.

Returns:

TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode]()), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode]())]

scales_raw_ptr

scales_raw_ptr(self) -> UnsafePointer[Float32, MutAnyOrigin]

Returns the base pointer to the quantization scales tensor.

Returns:

UnsafePointer[Float32, MutAnyOrigin]