Mojo struct
KVCacheMHAOperand
struct KVCacheMHAOperand[cache_t: KVCacheT]
An implementation of the MHAOperand trait for mo.opaque KVCacheT arguments to MHA kernels.
We can eventually remove this adapter and instead make MHAOperand a sub-trait of the KVCacheT type, but some cyclic dependencies need to be resolved first.
Fields
- cache (cache_t)
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
MHAOperand,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type = KVCacheMHAOperand[cache_t]
dtype
comptime dtype = cache_t.dtype
page_size
comptime page_size = cache_t.page_size_
quantization_enabled
comptime quantization_enabled = cache_t.quantization_enabled
quantization_granularity
comptime quantization_granularity = cache_t.quantization_granularity
scale_dtype
comptime scale_dtype = cache_t.scale_dtype
Methods
__init__
__init__(cache: cache_t) -> Self
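For illustration, a minimal construction sketch (hypothetical: k_cache stands for any existing value that conforms to KVCacheT, such as a per-layer view of a paged KV cache):
# Wrap a KVCacheT value so MHA kernels can consume it through the MHAOperand interface.
var k_operand = KVCacheMHAOperand(k_cache)
# The comptime members (dtype, page_size, scale_dtype, ...) mirror the wrapped cache type.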
get_type_name
block_paged_ptr
block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[KVCacheMHAOperand[cache_t].dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[KVCacheMHAOperand[cache_t].dtype], ImmutAnyOrigin]
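A hypothetical usage sketch, continuing with the k_operand value from the construction example above (the tile size and index values are arbitrary):
# Device pointer to the page-aligned block holding token 128 of batch entry 0
# for KV head 3, starting at head-dim offset 0 (the default).
var kv_ptr = k_operand.block_paged_ptr[64](UInt32(0), UInt32(128), UInt32(3))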
scales_block_paged_ptr
scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[KVCacheMHAOperand[cache_t].scale_dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[KVCacheMHAOperand[cache_t].scale_dtype], ImmutAnyOrigin]
load_scale
load_scale[width: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[KVCacheMHAOperand[cache_t].scale_dtype, width]
Returns:
SIMD[KVCacheMHAOperand[cache_t].scale_dtype, width]
cache_length
max_context_length
num_kv_rows
num_kv_rows(self) -> Int
Returns the total number of virtual rows in the KV memory view.
Returns:
Int
row_idx
row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32
Returns the row index when viewing the KV memory as a matrix.
Returns:
UInt32
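A hypothetical sketch of how num_kv_rows and row_idx relate (index values are arbitrary):
# Flat row of the first token of batch entry 2 when the KV memory is viewed
# as a single matrix of rows; num_kv_rows gives the height of that view.
var row = k_operand.row_idx(UInt32(2), UInt32(0))
var total_rows = k_operand.num_kv_rows()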
populate
populate[BN: Int, pair_cta: Bool = False, is_leader: Bool = True](self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, KVCacheMHAOperand[cache_t].page_size, pair_cta, is_leader]
Delegates to the underlying cache's populate. PagedKVCache.populate overrides this with a SIMD lookup-table read; other cache types fall through to the scalar default on KVCacheT.populate.
Returns:
PagedRowIndices[BN, KVCacheMHAOperand[cache_t].page_size, pair_cta, is_leader]
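An illustrative (hypothetical) call, assuming BN matches the kernel's key/value tile height and using batch_idx = 0, base_kv_row = 0:
# Gather the row indices backing one BN-row tile of keys/values that starts
# at base_kv_row for this batch entry.
var rows = k_operand.populate[BN=128](UInt32(0), UInt32(0))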
get_tma_row
get_tma_row(self, encoded_index: Int32) -> Int32
Converts an encoded sparse index to a physical TMA row.
Returns:
Int32
create_tma_tile
create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[KVCacheMHAOperand[cache_t].dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: TMATensorTile[KVCacheMHAOperand[cache_t].dtype, 3, _padded_shape[3, KVCacheMHAOperand[cache_t].dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, KVCacheMHAOperand[cache_t].dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()])
Creates a TMA tile for efficient GPU memory transfers.
Returns:
TMATensorTile[KVCacheMHAOperand[cache_t].dtype, 3, _padded_shape[3, KVCacheMHAOperand[cache_t].dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, KVCacheMHAOperand[cache_t].dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]
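A hypothetical sketch, assuming ctx is an existing DeviceContext and that a 128-byte swizzle with a 64 x 128 K/V tile suits the kernel in question:
# 3D TMA descriptor over this operand's K or V memory, letting the kernel
# issue hardware TMA copies of BN x depth tiles into shared memory.
var kv_tma = k_operand.create_tma_tile[
    TensorMapSwizzle.SWIZZLE_128B, BN=64, depth=128
](ctx)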
create_scale_tma_tile
create_scale_tma_tile[BMN: Int](self, ctx: DeviceContext, out tma: TMATensorTile[KVCacheMHAOperand[cache_t].scale_dtype, 2, Index[Int, Int](1, BMN)])
Creates a TMA tile for efficient GPU memory transfers. This is useful for m-major MMA operations where we don't need to mask any extra rows.
Returns:
TMATensorTile[KVCacheMHAOperand[cache_t].scale_dtype, 2, Index[Int, Int](1, BMN)]
create_ragged_tma_tile
create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[KVCacheMHAOperand[cache_t].dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[KVCacheMHAOperand[cache_t].dtype, swizzle_mode, BM=BN, BN=BK])
Returns:
RaggedTMA3DTile[KVCacheMHAOperand[cache_t].dtype, swizzle_mode, BM=BN, BN=BK]
create_rope_tma_tile
create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()])
Delegates to the underlying KVCache to create a BF16 rope TMA tile.
Returns:
TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]
create_gather4_tma_tile
create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, tma_dtype: DType = KVCacheMHAOperand[cache_t].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])
Creates a 2D TMA gather4 descriptor for this KV cache operand.
Returns:
TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]
create_rope_gather4_tma_tile
create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])
Delegates to the underlying KVCache to create a BF16 rope gather4 TMA tile.
Returns:
TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]
scales_raw_ptr
scales_raw_ptr(self) -> UnsafePointer[Float32, MutAnyOrigin]
Returns the base pointer to the quantization scales tensor.
Returns:
UnsafePointer[Float32, MutAnyOrigin]