Mojo module

kv_buffer

KV cache buffers for MHA/MLA prefill and decode kernels.

Provides KVCacheIterator (TileTensor-based DRAM tile iteration), the KVBufferConfig trait with its K and V implementors (KBufferConfig, VBufferConfig), and three KV buffer structs:

  • KVBuffer: double-buffered DMA + LDS + register tile management used by MHA/MLA prefill (owns its DRAM iterator).
  • DecodeStreamingKVBuffer: single-buffer per-strip DMA used by the streaming decode kernel (takes an external DRAM tile per iteration).
  • DecodeKVBuffer: double-buffered register staging used by the decode mirror path (parametrized by KVBufferConfig for K vs V roles).
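
The difference between the double-buffered and single-buffer variants above comes down to when the next tile's load is issued relative to the compute on the current tile. The following is a minimal sketch in plain Python, not Mojo, and it does not use any API from this module: `double_buffered_sum` and its event log are hypothetical names chosen for illustration, and the "load" and "compute" steps merely stand in for DMA and MMA work. It assumes the common scheme of two staging slots that alternate roles each iteration.

```python
from typing import List, Optional, Sequence, Tuple


def double_buffered_sum(tiles: Sequence[List[int]]) -> Tuple[int, List[str]]:
    """Consume tiles through two staging slots and record the schedule.

    Hypothetical model: "load" copies a tile into a slot (standing in for
    a DMA), "compute" sums the slot contents (standing in for the MMA).
    """
    slots: List[Optional[List[int]]] = [None, None]
    log: List[str] = []
    total = 0
    if not tiles:
        return total, log

    # Prologue: fill slot 0 before the main loop starts.
    slots[0] = list(tiles[0])
    log.append("load 0 -> slot 0")

    for i in range(len(tiles)):
        cur = i % 2        # slot holding the tile to compute on
        nxt = (i + 1) % 2  # idle slot that receives the next tile

        # Issue the next load first so it can overlap with the compute.
        if i + 1 < len(tiles):
            slots[nxt] = list(tiles[i + 1])
            log.append(f"load {i + 1} -> slot {nxt}")

        total += sum(slots[cur])
        log.append(f"compute {i} <- slot {cur}")

    return total, log
```

In this model a single-buffer variant would have only one slot, so each load would have to wait for the previous compute to finish; the second slot is what lets the two steps overlap.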

TileTensor is used throughout:

  • DRAM tiles: TileTensor with RuntimeInt valid_rows (KVCacheIterator)
  • SMEM sub-tiles: .tile() views on a strided parent TileTensor that mirrors the blocked (BN × BK) SMEM layout
  • DMA: SubTileLoaderLDS / RegTileLoader (both src and dst are TileTensor)
  • LDS loads: KVMmaOp.load_prefill / load_v_bf16 / load_v_fp8_strip (TileTensor SMEM -> reg-tile fragments)
  • MMA register tiles: TileTensor in LOCAL with stack_allocation
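
As a rough illustration of the tiling ideas in the list above, the sketch below models a sub-tile view on a strided parent and a runtime valid-row count for a partial last tile. It is plain Python rather than Mojo and does not reproduce TileTensor's actual API or the module's real SMEM layout: `tile_offsets` and `valid_rows` are hypothetical helpers, and the row-major strided parent and the clamping rule for the last tile are assumptions made for the example.

```python
from typing import List


def tile_offsets(tile_row: int, tile_col: int, BN: int, BK: int,
                 row_stride: int, col_stride: int = 1) -> List[List[int]]:
    """Linear offsets of a (BN x BK) sub-tile inside a strided parent.

    Assumption: the parent is addressed as row * row_stride + col * col_stride.
    """
    base = tile_row * BN * row_stride + tile_col * BK * col_stride
    return [
        [base + r * row_stride + c * col_stride for c in range(BK)]
        for r in range(BN)
    ]


def valid_rows(tile_row: int, BN: int, num_rows: int) -> int:
    """Rows of tile `tile_row` that fall inside a tensor with `num_rows` rows.

    Assumption: tiles past the end are clamped to zero valid rows.
    """
    return max(0, min(BN, num_rows - tile_row * BN))
```

For example, with `BN = 4` and a 10-row tensor, tiles 0 and 1 are full and tile 2 has two valid rows, which is the kind of value that has to be carried at run time rather than fixed at compile time.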

comptime values

KBuffer

comptime KBuffer[dtype: DType, kv_tile_layout: TensorLayout, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = DecodeKVBuffer[KBufferConfig[BN, BK, WN], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]

Parameters

VBuffer

comptime VBuffer[dtype: DType, kv_tile_layout: TensorLayout, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = DecodeKVBuffer[VBufferConfig[BN, BK, WN, depth], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]

Parameters

Structs

Traits