Mojo module
kv_buffer
KV cache buffers for MHA/MLA prefill and decode kernels.
Provides KVCacheIterator (TileTensor-based DRAM tile iteration), the KVBufferConfig trait and its K/V implementors (KBufferConfig, VBufferConfig), and three KV buffer structs:
- KVBuffer: double-buffered DMA + LDS + register tile management used by MHA/MLA prefill (owns its DRAM iterator).
- DecodeStreamingKVBuffer: single-buffer per-strip DMA used by the streaming decode kernel (takes an external DRAM tile per iteration).
- DecodeKVBuffer: double-buffered register staging used by the decode mirror path (parametrized by KVBufferConfig for K vs V roles).
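Most of the difference between the double-buffered structs (KVBuffer, DecodeKVBuffer) and the single-buffer DecodeStreamingKVBuffer is scheduling. The plain-Python sketch below models only the order of loads and computes. It is not the Mojo implementation: it ignores async copies and barriers, and the event tuples and function names are hypothetical.

```python
"""Toy model of KV tile staging order (not the Mojo kernel code).

Each event is (kind, tile_index, slot): "load" fills a buffer slot with a
tile, "compute" consumes the tile currently held in that slot.
"""


def double_buffered_schedule(num_tiles):
    """Ping-pong staging: prefetch tile i+1 into the other slot, then compute tile i."""
    events = []
    if num_tiles == 0:
        return events
    events.append(("load", 0, 0))  # prologue: tile 0 goes into slot 0
    for i in range(num_tiles):
        if i + 1 < num_tiles:
            events.append(("load", i + 1, (i + 1) % 2))  # prefetch next tile
        events.append(("compute", i, i % 2))
    return events


def single_buffer_schedule(num_tiles):
    """One slot: each load waits until the previous tile has been consumed."""
    events = []
    for i in range(num_tiles):
        events.append(("load", i, 0))
        events.append(("compute", i, 0))
    return events


def check_schedule(events):
    """Raise if a slot is refilled before its tile is consumed, or if a compute reads the wrong tile."""
    held = {}
    for kind, tile, slot in events:
        if kind == "load":
            if slot in held:
                raise RuntimeError(f"slot {slot} overwritten before tile {held[slot]} was used")
            held[slot] = tile
        elif held.pop(slot, None) != tile:
            raise RuntimeError(f"compute on tile {tile} found wrong data in slot {slot}")
    return True
```

With two slots, the load of tile i+1 can be issued before compute on tile i, which is what lets it overlap. With one slot, every load is serialized behind the previous compute. That costs latency hiding but needs roughly half the staging memory for the same tile size.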
TileTensor is used throughout:
- DRAM tiles: TileTensor with RuntimeInt valid_rows (KVCacheIterator)
- SMEM sub-tiles: .tile() views on a strided parent TileTensor that mirrors the blocked (BN × BK) SMEM layout
- DMA: SubTileLoaderLDS / RegTileLoader (both src and dst are TileTensor)
- LDS loads: KVMmaOp.load_prefill / load_v_bf16 / load_v_fp8_strip (TileTensor SMEM -> reg-tile fragments)
- MMA register tiles: TileTensor in LOCAL with stack_allocation
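The SMEM bullet describes sub-tiles as .tile() views on a strided parent that mirrors a blocked (BN × BK) layout. The Python sketch below shows the index arithmetic under one stated assumption: a BN × depth tile is stored as consecutive BN × BK blocks along depth, and each block is row-major. The real layout and swizzle in this module may differ. blocked_offset and subtile_view are hypothetical helpers.

```python
def blocked_offset(row, col, BN, BK):
    """Linear SMEM offset of element (row, col) of a BN x depth tile.

    Assumed layout: consecutive row-major BN x BK blocks along depth.
    """
    if not 0 <= row < BN:
        raise IndexError("row out of range")
    block = col // BK          # which BN x BK block along depth
    within = col % BK          # column inside that block
    return block * (BN * BK) + row * BK + within


def subtile_view(block, BN, BK):
    """(origin, row_stride, col_stride) of the block-th BN x BK sub-tile.

    This is the strided view a .tile() call on the parent would describe
    under the assumed layout.
    """
    return block * BN * BK, BK, 1
```

Because each sub-tile is contiguous with strides (BK, 1), a view only needs an origin offset. Under this assumption, the loaders and MMA fragment loads can address every block the same way.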
comptime values
KBuffer
comptime KBuffer[dtype: DType, kv_tile_layout: TensorLayout, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = DecodeKVBuffer[KBufferConfig[BN, BK, WN], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]
Parameters
- dtype (DType)
- kv_tile_layout (TensorLayout)
- tensor_core_mma (TiledTensorCore[out_type, in_type, shape, group_size, transpose_b])
- swizzle (Optional[Swizzle])
- BN (Int)
- WN (Int)
- BK (Int)
- depth (Int)
- num_threads (Int)
- num_stages (Int)
- token_gen (Bool)
VBuffer
comptime VBuffer[dtype: DType, kv_tile_layout: TensorLayout, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = DecodeKVBuffer[VBufferConfig[BN, BK, WN, depth], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]
Parameters
- dtype (DType)
- kv_tile_layout (TensorLayout)
- tensor_core_mma (TiledTensorCore[out_type, in_type, shape, group_size, transpose_b])
- swizzle (Optional[Swizzle])
- BN (Int)
- WN (Int)
- BK (Int)
- depth (Int)
- num_threads (Int)
- num_stages (Int)
- token_gen (Bool)
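Read side by side, the KBuffer and VBuffer aliases differ only in which config they pass as DecodeKVBuffer's first parameter. Both accept dtype and kv_tile_layout without forwarding them. The Python stand-ins below mirror that parameter mapping exactly as the signatures above state it. They are a reading aid with NamedTuple placeholders, not Mojo types.

```python
from typing import Any, NamedTuple


# Python placeholders named after the Mojo types; they only record parameters.
class KBufferConfig(NamedTuple):
    BN: int
    BK: int
    WN: int


class VBufferConfig(NamedTuple):
    BN: int
    BK: int
    WN: int
    depth: int


class DecodeKVBuffer(NamedTuple):
    config: Any
    tensor_core_mma: Any
    swizzle: Any
    BN: int
    WN: int
    BK: int
    depth: int
    num_threads: int
    num_stages: int
    token_gen: bool


def KBuffer(dtype, kv_tile_layout, tensor_core_mma, swizzle, BN, WN, BK,
            depth, num_threads, num_stages=1, token_gen=False):
    # dtype and kv_tile_layout are accepted but not forwarded, as in the alias.
    return DecodeKVBuffer(KBufferConfig(BN, BK, WN), tensor_core_mma, swizzle,
                          BN, WN, BK, depth, num_threads, num_stages, token_gen)


def VBuffer(dtype, kv_tile_layout, tensor_core_mma, swizzle, BN, WN, BK,
            depth, num_threads, num_stages=1, token_gen=False):
    # Same mapping; only the role config differs (it also carries depth).
    return DecodeKVBuffer(VBufferConfig(BN, BK, WN, depth), tensor_core_mma,
                          swizzle, BN, WN, BK, depth, num_threads, num_stages,
                          token_gen)
```

Note the argument order. The configs take (BN, BK, WN), while DecodeKVBuffer takes BN, WN, BK, so the two are easy to transpose when these types are instantiated by hand.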
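Structs
- DecodeKVBuffer: Double-buffered register staging used by the decode mirror path, parameterized by KVBufferConfig for the K and V roles.
- DecodeStreamingKVBuffer: Streaming-decode KV buffer: single-buffer SMEM staging with per-strip DMA.
- KBufferConfig: KVBufferConfig implementor for the K role.
- KVBuffer: KV cache buffer managing DMA, LDS staging, and register tiles.
- KVCacheIterator: TileTensor-based DRAM tile iterator.
- VBufferConfig: KVBufferConfig implementor for the V role.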
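KVCacheIterator yields DRAM tiles whose valid_rows is a RuntimeInt rather than a compile-time constant. The Python generator below models the usual reason for that. It assumes the iterator walks num_keys cache rows in BN-row steps, so only the final tile can be partial. kv_tiles and its arguments are hypothetical and are not the iterator's actual interface.

```python
def kv_tiles(num_keys, BN, start=0):
    """Yield (first_row, valid_rows) for BN-row tiles covering [start, num_keys).

    valid_rows is only known at run time because num_keys is; every tile
    except possibly the last one is full.
    """
    if BN <= 0:
        raise ValueError("BN must be positive")
    row = start
    while row < num_keys:
        valid = min(BN, num_keys - row)  # last tile may be partial
        yield row, valid
        row += BN
```

Rows at or beyond valid_rows in a partial tile would typically be masked or zero-filled by the loader. This page does not say how this module handles them.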
Traits
- KVBufferConfig