
Mojo struct

KVCacheScalesMHAOperand

struct KVCacheScalesMHAOperand[cache_t: KVCacheT]

An MHAOperand that accesses the scales field of a KVCache.

This is useful for MLA attention, where k_s (the per-token scales) are stored in the scales field of the K cache with quantization_granularity = head_size. The scales have shape [num_blocks, page_size, num_heads, head_dim_granularity].
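The scales layout above can be illustrated with a small, language-neutral sketch written in Python. It shows how one scale could be located in a row-major buffer of shape [num_blocks, page_size, num_heads, head_dim_granularity]. Two things here are assumptions rather than the MAX implementation: head_dim_granularity is taken to be head_dim // quantization_granularity, and the page_table list (mapping a batch entry's logical pages to physical blocks) is a hypothetical stand-in for the cache's real lookup table.

```python
# Hypothetical sketch: locate one scale in a paged scales buffer of shape
# [num_blocks, page_size, num_heads, head_dim_granularity], stored row-major.
# `page_table` is an illustrative assumption, not the MAX KVCache API.


def head_dim_granularity(head_dim: int, quantization_granularity: int) -> int:
    # Assumed: one scale per `quantization_granularity` elements of head_dim.
    assert head_dim % quantization_granularity == 0
    return head_dim // quantization_granularity


def scale_offset(page_table, page_size, num_heads, hdg,
                 batch_idx, tok_idx, head_idx, granule_idx=0):
    """Flat element offset of one scale in the row-major scales buffer."""
    # Split the logical token index into (logical page, slot within page).
    logical_page, slot = divmod(tok_idx, page_size)
    # Translate the logical page to a physical block via the page table.
    block = page_table[batch_idx][logical_page]
    # Row-major strides for [num_blocks, page_size, num_heads, hdg].
    return ((block * page_size + slot) * num_heads + head_idx) * hdg + granule_idx


# Two sequences, page_size = 4, two heads; batch 0 uses blocks 3 then 0.
page_table = [[3, 0], [1, 2]]
print(scale_offset(page_table, 4, 2, 1, 0, 5, 1))
```

With quantization_granularity = head_size, as in the MLA case described above, head_dim_granularity is 1, so each (token, head) pair owns exactly one scale and the last index is always 0.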

Fields

  • cache (cache_t): The underlying KV cache whose scales field this operand accesses.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, MHAOperand, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = KVCacheScalesMHAOperand[cache_t]

dtype

comptime dtype = cache_t.scale_dtype

page_size

comptime page_size = cache_t.page_size_

quantization_enabled

comptime quantization_enabled = cache_t.quantization_enabled

quantization_granularity

comptime quantization_granularity = cache_t.quantization_granularity

scale_dtype

comptime scale_dtype = DType.invalid

Methods

__init__

def __init__(cache: cache_t) -> Self

get_type_name

static def get_type_name() -> String

Returns:

String

block_paged_ptr

def block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]

scales_block_paged_ptr

def scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[DType.invalid], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[DType.invalid], ImmutAnyOrigin]

load_scale

def load_scale[width: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[DType.invalid, width]

Returns:

SIMD[DType.invalid, width]

cache_length

def cache_length(self, batch_idx: Int) -> Int

Returns:

Int

max_context_length

def max_context_length(self) -> UInt32

Returns:

UInt32

num_kv_rows

def num_kv_rows(self) -> Int

Returns the total number of virtual rows in the KV memory view.

Returns:

Int

row_idx

def row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index when viewing the memory as a matrix.

Returns:

UInt32
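The "memory as a matrix" view used by num_kv_rows and row_idx can be pictured with the following Python sketch. It assumes, purely for illustration, that each physical token slot of each block is one virtual row, so that row = physical_block * page_size + slot. The page_table list is a hypothetical stand-in for the cache's lookup table, and the real MAX layout may differ.

```python
# Hypothetical sketch: view a paged buffer of shape [num_blocks, page_size, ...]
# as a 2D matrix with one row per physical token slot.
# Assumes row = physical_block * page_size + slot; not the confirmed MAX layout.


def num_kv_rows(num_blocks: int, page_size: int) -> int:
    # Every token slot in every physical block is one virtual row.
    return num_blocks * page_size


def row_idx(page_table, page_size: int, batch_idx: int, start_tok_idx: int) -> int:
    # Logical token -> (logical page, slot in page) -> physical block -> row.
    logical_page, slot = divmod(start_tok_idx, page_size)
    block = page_table[batch_idx][logical_page]
    return block * page_size + slot


page_table = [[3, 0], [1, 2]]  # two sequences spread over four blocks
print(num_kv_rows(4, 4), row_idx(page_table, 4, 1, 7))
```

Under this assumption every value returned by row_idx falls in the range [0, num_kv_rows), which is what makes the total row count a useful bound for a 2D view of the buffer.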

get_tma_row

def get_tma_row(self, encoded_index: Int32) -> Int32

Convert an encoded sparse index to a physical TMA row.

Returns:

Int32

create_tma_tile

def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[cache_t.scale_dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])

TMA is not supported for KVCacheScalesMHAOperand.

Returns:

TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_scale_tma_tile

def create_scale_tma_tile[BMN: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.invalid, Int(2), Index[Int, Int](Int(1), BMN)])

Returns:

TMATensorTile[DType.invalid, Int(2), Index[Int, Int](Int(1), BMN)]

create_ragged_tma_tile

def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[cache_t.scale_dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK])

TMA is not supported for KVCacheScalesMHAOperand.

Returns:

RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK]

create_rope_tma_tile

def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Not supported for KVCacheScalesMHAOperand.

Returns:

TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_gather4_tma_tile

def create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), tma_dtype: DType = KVCacheScalesMHAOperand[cache_t].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])

Not supported for KVCacheScalesMHAOperand.

Returns:

TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

create_rope_gather4_tma_tile

def create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])

Not supported for KVCacheScalesMHAOperand.

Returns:

TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

scales_raw_ptr

def scales_raw_ptr(self) -> UnsafePointer[Float32, MutAnyOrigin]

Returns a dangling pointer, because KVCacheScalesMHAOperand already points to the scales themselves.

Returns:

UnsafePointer[Float32, MutAnyOrigin]