Mojo module

kv_buffer

KV cache buffers for MHA/MLA prefill and decode kernels.

Provides KVCacheIterator (TileTensor-based DRAM tile iteration), the KVBufferConfig trait and its K/V implementors, and three KV buffer structs:

  • KVBuffer: double-buffered DMA + LDS + register tile management used by MHA/MLA prefill (owns its DRAM iterator).
  • DecodeStreamingKVBuffer: single-buffer per-strip DMA used by the streaming decode kernel (takes an external DRAM tile per iteration).
  • DecodeKVBuffer: double-buffered register staging used by the decode mirror path (parametrized by KVBufferConfig for K vs V roles).
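As a rough illustration of what double buffering buys over a single buffer, the Python sketch below models the load/compute order on the host. It is a model under stated assumptions, not Mojo and not the kernel's code: `kv_tiles` is a hypothetical stand-in for KVCacheIterator that yields each BN-row tile with its valid row count, and `schedule` is a hypothetical helper that records which buffer stage each load and compute step uses when `num_stages` buffers rotate.

```python
# Hypothetical host-side model; names are illustrative, not the Mojo API.
from dataclasses import dataclass
from typing import Iterator, List, Tuple


@dataclass
class KVTile:
    index: int       # tile number along the sequence
    start_row: int   # first cache row covered by this tile
    valid_rows: int  # rows holding real tokens (< BN only for the last tile)


def kv_tiles(seq_len: int, BN: int) -> Iterator[KVTile]:
    """Stand-in for KVCacheIterator: walk the cache in BN-row tiles."""
    for i, start in enumerate(range(0, seq_len, BN)):
        yield KVTile(i, start, min(BN, seq_len - start))


def schedule(seq_len: int, BN: int, num_stages: int) -> List[Tuple[str, int, int]]:
    """Return (event, tile_index, stage) tuples for a prefetching pipeline.

    With num_stages >= 2, tile i + 1 is loaded into another stage before
    tile i is consumed, so the load overlaps the compute. With a single
    stage, each load has to wait until the previous compute is done.
    """
    tiles = list(kv_tiles(seq_len, BN))
    events: List[Tuple[str, int, int]] = []
    if not tiles:
        return events
    events.append(("load", 0, 0))  # prologue: fill the first stage
    for t in tiles:
        stage = t.index % num_stages
        nxt = t.index + 1
        if num_stages > 1 and nxt < len(tiles):
            # Prefetch the next tile into a different stage before computing.
            events.append(("load", nxt, nxt % num_stages))
        events.append(("compute", t.index, stage))
        if num_stages == 1 and nxt < len(tiles):
            # Single buffer: the next load can start only after this compute.
            events.append(("load", nxt, 0))
    return events
```

For example, with seq_len = 300 and BN = 128, `kv_tiles` yields tiles with 128, 128 and 44 valid rows. With num_stages = 2, `schedule` issues the load of tile 1 before computing tile 0; with num_stages = 1, every load waits for the preceding compute.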

TileTensor is used throughout:

  • DRAM tiles: TileTensor with Scalar valid_rows (KVCacheIterator)
  • SMEM sub-tiles: .tile() views on a strided parent TileTensor that mirrors the blocked (BN × BK) SMEM layout
  • DMA: SubTileLoaderLDS / RegTileLoader (both src and dst are TileTensor)
  • LDS loads: KVMmaOp.load_prefill / load_v_bf16 / load_v_fp8_strip (TileTensor SMEM -> reg-tile fragments)
  • MMA register tiles: TileTensor in LOCAL with stack_allocation
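The blocked SMEM layout can be pictured with plain offset arithmetic. The Python sketch below assumes a BN × depth tile stored as depth / BK consecutive BN × BK blocks, each row-major, with no swizzle or padding (the real kernels take an optional Swizzle, which this ignores). `smem_offset`, `SubTileView` and `tile_view` are hypothetical names used only for illustration, not TileTensor APIs.

```python
# Hypothetical model of a blocked (BN x BK) SMEM layout; not the TileTensor API.
from dataclasses import dataclass


def smem_offset(row: int, col: int, BN: int, BK: int) -> int:
    """Element offset of (row, col) in a BN x depth tile stored as BN x BK blocks.

    Block b = col // BK starts at b * BN * BK; inside a block, elements are
    row-major with a row stride of BK. Swizzling is ignored.
    """
    block, col_in_block = divmod(col, BK)
    return block * BN * BK + row * BK + col_in_block


@dataclass
class SubTileView:
    """A rows x cols window described by a base offset and a row stride."""
    base: int
    row_stride: int
    rows: int
    cols: int

    def offset(self, r: int, c: int) -> int:
        assert 0 <= r < self.rows and 0 <= c < self.cols
        return self.base + r * self.row_stride + c


def tile_view(block_idx: int, BN: int, BK: int) -> SubTileView:
    """Analogue of taking the block_idx-th BN x BK sub-tile with .tile()."""
    return SubTileView(base=block_idx * BN * BK, row_stride=BK, rows=BN, cols=BK)
```

Under this assumption each BK-wide slice of the tile is one contiguous BN × BK region, so a sub-tile view needs only a base offset of block_idx * BN * BK and a row stride of BK. For example, with BN = 4 and BK = 2, element (1, 3) of the parent and element (1, 1) of the second sub-tile both map to offset 11.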

comptime values

KBuffer

comptime KBuffer[dtype: DType, kv_tile_layout: TensorLayout, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = DecodeKVBuffer[KBufferConfig[BN, BK, WN], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]

Parameters

VBuffer

comptime VBuffer[dtype: DType, kv_tile_layout: TensorLayout, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = DecodeKVBuffer[VBufferConfig[BN, BK, WN, depth], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]
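KBuffer and VBuffer are thin aliases: both resolve to DecodeKVBuffer and differ only in the role config they bind, KBufferConfig[BN, BK, WN] versus VBufferConfig[BN, BK, WN, depth]. The dtype and kv_tile_layout parameters appear in both alias signatures but are not forwarded. The Python sketch below mirrors that structure with plain dataclasses; it models only the parameter plumbing visible in the two signatures, and `make_k_buffer` / `make_v_buffer` are hypothetical helpers, not part of the Mojo API.

```python
# Hypothetical model of the KBuffer/VBuffer alias plumbing; not the Mojo API.
from dataclasses import dataclass
from typing import Any, Optional, Union


@dataclass(frozen=True)
class KBufferConfig:
    BN: int
    BK: int
    WN: int


@dataclass(frozen=True)
class VBufferConfig:
    BN: int
    BK: int
    WN: int
    depth: int  # only the V role config receives depth


@dataclass(frozen=True)
class DecodeKVBuffer:
    config: Union[KBufferConfig, VBufferConfig]
    tensor_core_mma: Any
    swizzle: Optional[Any]
    BN: int
    WN: int
    BK: int
    depth: int
    num_threads: int
    num_stages: int = 1
    token_gen: bool = False


def make_k_buffer(dtype, kv_tile_layout, tensor_core_mma, swizzle, BN, WN, BK,
                  depth, num_threads, num_stages=1, token_gen=False):
    # Like the KBuffer alias: dtype and kv_tile_layout are accepted, not forwarded.
    return DecodeKVBuffer(KBufferConfig(BN, BK, WN), tensor_core_mma, swizzle,
                          BN, WN, BK, depth, num_threads, num_stages, token_gen)


def make_v_buffer(dtype, kv_tile_layout, tensor_core_mma, swizzle, BN, WN, BK,
                  depth, num_threads, num_stages=1, token_gen=False):
    # Like the VBuffer alias: same target struct, but the config also gets depth.
    return DecodeKVBuffer(VBufferConfig(BN, BK, WN, depth), tensor_core_mma, swizzle,
                          BN, WN, BK, depth, num_threads, num_stages, token_gen)
```

Calling make_k_buffer with depth = 128 yields a config without a depth field, while make_v_buffer passes depth into its config; both return the same DecodeKVBuffer type.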

Parameters

Structs

Traits