For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
KVCacheMHAOperand
struct KVCacheMHAOperand[cache_t: KVCacheT]
An MHAOperand implementation that adapts mo.opaque KVCacheT arguments for use by MHA kernels.
We could eventually remove this wrapper and expose the MHAOperand interface directly as a sub-trait of the KVCacheT type, but some cyclic dependencies must be resolved first.
Fields
- cache (cache_t): The underlying KV cache that this operand wraps.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDeletable,
MHAOperand,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type = KVCacheMHAOperand[cache_t]
dtype
comptime dtype = cache_t.dtype
page_size
comptime page_size = cache_t.page_size_
quantization_enabled
comptime quantization_enabled = cache_t.quantization_enabled
quantization_granularity
comptime quantization_granularity = cache_t.quantization_granularity
scale_dtype
comptime scale_dtype = cache_t.scale_dtype
Methods
__init__
def __init__(cache: cache_t) -> Self
get_type_name
block_paged_ptr
def block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]
Returns:
scales_block_paged_ptr
def scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[Self.scale_dtype], ImmutAnyOrigin]
Returns:
load_scale
def load_scale[width: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[Self.scale_dtype, width]
Returns:
cache_length
max_context_length
num_kv_rows
def num_kv_rows(self) -> Int
Returns the total number of virtual rows in the KV memory view.
Returns:
row_idx
def row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32
Returns the row idx when viewing the memory as a matrix.
Returns:
populate
def populate[BN: Int, base_alignment: Int, pair_cta: Bool = False, is_leader: Bool = True](self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, cache_t.page_size_, pair_cta, is_leader]
Delegate to the underlying cache's populate.
PagedKVCache.populate overrides with a SIMD lookup-table read;
other cache types fall through to the scalar default on
KVCacheT.populate. base_alignment is the comptime alignment
of base_kv_row (typically mask.start_column_alignment[...]()).
Returns:
PagedRowIndices[BN, cache_t.page_size_, pair_cta, is_leader]
get_tma_row
def get_tma_row(self, encoded_index: Int32) -> Int32
Convert an encoded sparse index to a physical TMA row.
Returns:
create_tma_tile
def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[cache_t.dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])
Creates a TMA tile for efficient GPU memory transfers.
Returns:
create_scale_tma_tile
def create_scale_tma_tile[BMN: Int](self, ctx: DeviceContext, out tma: TMATensorTile[Self.scale_dtype, Int(2), Index[Int, Int](Int(1), BMN)])
Creates a TMA tile for efficient GPU memory transfers. This is useful for m-major MMA operations where we don't need to mask any extra rows.
Returns:
TMATensorTile[Self.scale_dtype, Int(2), Index[Int, Int](Int(1), BMN)]
create_ragged_tma_tile
def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[cache_t.dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK])
Returns:
create_rope_tma_tile
def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])
Delegates to the underlying KVCache to create a BF16 rope TMA tile.
Returns:
create_gather4_tma_tile
def create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), tma_dtype: DType = KVCacheMHAOperand[cache_t].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])
Creates a 2D TMA gather4 descriptor for this KV cache operand.
Returns:
create_rope_gather4_tma_tile
def create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])
Delegates to the underlying KVCache to create a BF16 rope gather4 TMA tile.
Returns:
scales_raw_ptr
def scales_raw_ptr(self) -> UnsafePointer[Float32, MutAnyOrigin]
Returns the base pointer to the quantization scales tensor.
Returns:
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!