Mojo struct

KVBuffer

struct KVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool, full_kv: Bool = True, cache_depth: Int = depth, head_dim_offset: Int = 0, reg_chunk_depth: Int = depth, smem_depth: Int = depth, bk_smem: Int = BK]

KV cache buffer managing DMA, LDS staging, and register tiles.

Handles the full data path: DRAM -> LDS (shared memory) -> registers.

SMEM is navigated via .tile() on a strided parent TileTensor whose (BK, BN) strides encode the blocked layout (num_repeats contiguous BN×BK blocks per stage, two stages). Stage selection and in-stage block selection both happen via the tile column index, so no pointer arithmetic is required. smem_mma_subtile is still used for V-operand MMA sub-tiles, which have mma_cols != BK.

When full_kv=True (depth<=256), each SMEM stage holds BN×smem_depth elements, i.e. the full tile. When full_kv=False (depth=512), each stage holds only BN×BK elements, and the caller iterates over BK blocks.

smem_depth defaults to depth. It exists for per-warp V buffers whose depth is smaller than BK (e.g. depth_per_warp=16 with BK=32): the SMEM layout stays valid with smem_depth = max(depth, BK) while depth keeps driving register-tile sizing and the column count read by load_from_shared.
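The stage-sizing rules above can be checked with a small plain-Python model of the smem_cols, smem_stage_size, and num_repeats members listed further down this page. This is only a sketch of the arithmetic, not the Mojo implementation; the helper name smem_stage_shape and all concrete sizes (BN=64, BK=32, and so on) are hypothetical values chosen for illustration.

```python
from typing import Optional


def smem_stage_shape(
    BN: int,
    BK: int,
    depth: int,
    full_kv: bool = True,
    smem_depth: Optional[int] = None,
    bk_smem: Optional[int] = None,
) -> dict:
    """Model of the per-stage SMEM sizing members (hypothetical helper)."""
    # Parameter defaults from the struct signature: smem_depth = depth, bk_smem = BK.
    smem_depth = depth if smem_depth is None else smem_depth
    bk_smem = BK if bk_smem is None else bk_smem

    # comptime smem_cols = smem_depth if full_kv else bk_smem
    smem_cols = smem_depth if full_kv else bk_smem
    return {
        "smem_cols": smem_cols,
        # comptime smem_stage_size = BN * smem_cols
        "smem_stage_size": BN * smem_cols,
        # comptime num_repeats = smem_depth // bk_smem
        "num_repeats": smem_depth // bk_smem,
    }


# full_kv=True: each stage holds the full BN x smem_depth tile.
full = smem_stage_shape(BN=64, BK=32, depth=128, full_kv=True)

# full_kv=False: each stage holds a single BN x BK block.
blocked = smem_stage_shape(BN=64, BK=32, depth=512, full_kv=False)

# Per-warp V buffer: depth_per_warp=16 with BK=32, so smem_depth = max(depth, BK).
per_warp_v = smem_stage_shape(BN=64, BK=32, depth=16, smem_depth=max(16, 32))

# The same buffer with smem_depth left at its default (depth=16).
per_warp_v_default = smem_stage_shape(BN=64, BK=32, depth=16)
```

With the default smem_depth the last configuration evaluates num_repeats to 0, which is presumably why the per-warp V case needs smem_depth = max(depth, BK).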

MMA register tiles (mma_tile) are TileTensors in the LOCAL address space. TiledMmaOp (mma.mojo) handles SMEM→register loads and MMA dispatch.

Fields​

  • ​kv_mma_op (KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].KVMmaOpType):
  • ​smem_tile (KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].SmemParentType):
  • ​kv_cache_iter (KVCacheIterator[kv_t, BN, kv_num_heads, depth, cache_depth, head_dim_offset]):
  • ​warp_id (UInt32):

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

input_frag_size​

comptime input_frag_size = ((KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].MMA_K * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].MMA_N) // WARP_SIZE)

KVMmaOpType​

comptime KVMmaOpType = KVMmaOp[kv_t.dtype, mma_shape, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].num_mmas, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].num_k_mmas2, ceildiv(reg_chunk_depth, BK) if transpose else KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].num_k_tiles, BN, BK, transpose, swizzle]

MMA_K​

comptime MMA_K = mma_shape[2]

mma_layout​

comptime mma_layout = row_major[((KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].num_mmas * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].num_k_mmas2) * ceildiv(reg_chunk_depth, BK) if transpose else KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].num_k_tiles), KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].input_frag_size]()

MMA_N​

comptime MMA_N = mma_shape[1]

MMATileType​

comptime MMATileType = TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

num_k_mmas2​

comptime num_k_mmas2 = ceildiv(BK, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].MMA_K)

num_k_tiles​

comptime num_k_tiles = ceildiv(depth if transpose else WN, BK)

num_mmas​

comptime num_mmas = ceildiv(WN if transpose else depth, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].MMA_N)
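The derived counts in this section (MMA_K, MMA_N, num_k_mmas2, num_k_tiles, num_mmas, input_frag_size, and the shape of mma_layout) can be worked through numerically. The following plain-Python sketch mirrors those expressions; it is not the Mojo source. It assumes that the conditional in mma_layout binds to the last factor only, which matches the third parameter passed to KVMmaOp in KVMmaOpType. The helper name mma_counts, the warp_size value of 64, and the example shapes are all hypothetical.

```python
from typing import Optional, Tuple


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division, standing in for Mojo's ceildiv."""
    return -(-a // b)


def mma_counts(
    mma_shape: Tuple[int, int, int],
    WN: int,
    BK: int,
    depth: int,
    transpose: bool,
    reg_chunk_depth: Optional[int] = None,
    warp_size: int = 64,  # assumption: WARP_SIZE is target-dependent
) -> dict:
    # Struct default: reg_chunk_depth = depth.
    reg_chunk_depth = depth if reg_chunk_depth is None else reg_chunk_depth

    MMA_N = mma_shape[1]
    MMA_K = mma_shape[2]

    num_k_mmas2 = ceil_div(BK, MMA_K)
    num_k_tiles = ceil_div(depth if transpose else WN, BK)
    num_mmas = ceil_div(WN if transpose else depth, MMA_N)
    input_frag_size = (MMA_K * MMA_N) // warp_size

    # Assumed grouping: the conditional selects the K-tile factor only.
    k_tile_factor = ceil_div(reg_chunk_depth, BK) if transpose else num_k_tiles
    mma_layout_shape = (num_mmas * num_k_mmas2 * k_tile_factor, input_frag_size)

    return {
        "num_k_mmas2": num_k_mmas2,
        "num_k_tiles": num_k_tiles,
        "num_mmas": num_mmas,
        "input_frag_size": input_frag_size,
        "mma_layout_shape": mma_layout_shape,
    }


# Hypothetical 32x32x16 MMA shape with WN=32, BK=32, depth=128.
k_like = mma_counts((32, 32, 16), WN=32, BK=32, depth=128, transpose=True)
v_like = mma_counts((32, 32, 16), WN=32, BK=32, depth=128, transpose=False)
```

Under these example sizes both settings of transpose give a register tile of 8 rows by 8 elements per lane, but the rows are split differently: 1 MMA across 4 K-tiles when transpose is True, and 4 MMAs across 1 K-tile when it is False.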

num_repeats​

comptime num_repeats = (smem_depth // bk_smem)

simd_width​

comptime simd_width = simd_width_of[kv_t.dtype]()

smem_cols​

comptime smem_cols = smem_depth if full_kv else bk_smem

smem_stage_size​

comptime smem_stage_size = (BN * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].smem_cols)

SmemParentType​

comptime SmemParentType = TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

vm_instrs_per_load​

comptime vm_instrs_per_load = SIMD((ceildiv(((BN // 32) * (smem_depth // bk_smem) if full_kv else 1), (num_threads // WARP_SIZE)) * 2))

warp_tile_rows​

comptime warp_tile_rows = 32

wtile_dim0​

comptime wtile_dim0 = WN

wtile_dim1​

comptime wtile_dim1 = BK

Methods​

__init__​

def __init__(out self, k_cache: kv_t, batch_idx: Int, head_idx: Int, smem_tile: TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], end: Int, warp_id: UInt32)

smem_block_tile​

def smem_block_tile[tile_rows: Int](self, tile_row: Int, block_col: Int) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

Get a (tile_rows, bk_smem) row-major sub-tile from SMEM.

tile_row indexes along BN (rows within a BN×bk_smem block); block_col indexes linearly across all BN×bk_smem blocks in both stages.

Returns:

TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
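One way to picture the linear block_col index used by smem_block_tile is to split it into a stage index and an in-stage block index. The sketch below is plain Python, not the Mojo implementation. It assumes full_kv=True, num_repeats blocks per stage, two stages stored back to back, and contiguous BN×bk_smem blocks, so the element offset it reports is an inference from the layout description rather than a documented guarantee. The helper name locate_block and the example sizes are hypothetical.

```python
NUM_STAGES = 2  # the layout description mentions two stages


def locate_block(block_col: int, BN: int, smem_depth: int, bk_smem: int) -> dict:
    """Split a linear block_col into (stage, block within stage)."""
    num_repeats = smem_depth // bk_smem  # blocks per stage when full_kv=True
    if not 0 <= block_col < NUM_STAGES * num_repeats:
        raise IndexError("block_col is outside both stages")

    stage, block_in_stage = divmod(block_col, num_repeats)

    # Assumption: blocks are contiguous and stages are adjacent, so the
    # first element of the block sits at block_col * (BN * bk_smem).
    element_offset = block_col * BN * bk_smem

    return {
        "stage": stage,
        "block_in_stage": block_in_stage,
        "element_offset": element_offset,
    }


# Hypothetical sizes: BN=64, smem_depth=128, bk_smem=32 gives 4 blocks per stage.
first = locate_block(0, BN=64, smem_depth=128, bk_smem=32)
sixth = locate_block(5, BN=64, smem_depth=128, bk_smem=32)
```

Here block_col=5 lands in the second stage (stage 1) at in-stage block 1, which is the kind of selection the column index performs without explicit pointer arithmetic.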

load_from_dram​

def load_from_dram[buffer_idx: Int](mut self)

get_mma_tile​

def get_mma_tile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Returns:

TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

mma_subtile​

def mma_subtile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Alias for get_mma_tile, kept for decode-call-site symmetry.

Returns:

TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

zero_partial_tile_pad​

def zero_partial_tile_pad(self)

Register-side zero for the OOB tail of the partial K-tile.

When depth % BK != 0, the last K-tile (i = depth // BK) spans BK K-positions but only valid_cols = depth - i*BK are valid; the trailing BK - valid_cols are pad and must read as zero.

Per-lane K-fragment layout: the input_frag_size elements per lane interleave across MFMA-K such that the LOWER valid_per_lane = input_frag_size * valid_cols / BK elements correspond to the valid K-range and the UPPER input_frag_size - valid_per_lane elements correspond to the pad. Zero the upper portion per lane.

The reference pre-zeros half of the partial-tile K-fragment dwords once and reuses them; we re-zero after each K LDS load because the LDS loader fills the whole register tile.

A no-op when depth % BK == 0.

NOTE: keep in sync with QRegisterBuffer.__init__'s partial-tile zero in buffers.mojo. Both sites share the upper-half-is-pad assumption (asserted below); a future config that violates it (valid_cols > BK/2) needs a different zero pattern in both.
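The per-lane zeroing rule described for zero_partial_tile_pad can be modelled on a single lane's fragment. The sketch below is plain Python operating on a list, not the register-level Mojo code. The function name zero_partial_tile_pad_model and the example sizes are hypothetical, and the assertion encodes one reading of the upper-half-is-pad assumption (valid_cols no larger than BK/2).

```python
from typing import List


def zero_partial_tile_pad_model(frag: List[int], depth: int, BK: int) -> List[int]:
    """Zero the pad portion of one lane's fragment for the last K-tile.

    frag holds input_frag_size elements belonging to a single lane.
    """
    if depth % BK == 0:
        # No partial tile: the operation is a no-op.
        return list(frag)

    i = depth // BK                 # index of the last (partial) K-tile
    valid_cols = depth - i * BK     # valid K-positions in that tile

    # Upper-half-is-pad assumption, as read from the note above.
    assert 2 * valid_cols <= BK, "config needs a different zero pattern"

    input_frag_size = len(frag)
    valid_per_lane = input_frag_size * valid_cols // BK

    # Keep the lower valid_per_lane elements, zero the upper remainder.
    return list(frag[:valid_per_lane]) + [0] * (input_frag_size - valid_per_lane)


lane = [1, 2, 3, 4, 5, 6, 7, 8]  # hypothetical fragment, input_frag_size = 8

# depth=80, BK=32: last tile index 2, valid_cols = 16, valid_per_lane = 4.
padded = zero_partial_tile_pad_model(lane, depth=80, BK=32)

# depth=64, BK=32: depth is a multiple of BK, so nothing changes.
untouched = zero_partial_tile_pad_model(lane, depth=64, BK=32)
```

In the first call the lower four elements survive and the upper four become zero; in the second the fragment is returned unchanged.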

load_from_shared​

def load_from_shared(self, buffer: Int)

def load_from_shared[bk_tile: Int](self, buffer: Int)