Mojo struct
KVBuffer
struct KVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], k_group_size: Int, swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool]
KV cache buffer managing DMA, LDS staging, and register tiles.
Handles the full data path: DRAM -> LDS (shared memory) -> registers.
SMEM is addressed via smem_subtile / smem_mma_subtile, which compute block-aligned offsets from the blocked_product layout structure (num_repeats contiguous BN×BK blocks) and return flat TileTensor views.
MMA register tiles (mma_tile) are TileTensor in LOCAL address space. TiledTensorCore.mma() has TileTensor overloads that construct LayoutTensor views at the MMA boundary (tensor_core.mojo).
Fields
- mma_tile (KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMATileType)
- smem_ptr (UnsafePointer[Scalar[kv_t.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
- kv_cache_iter (KVCacheIterator[kv_t, BN, kv_num_heads, depth])
- lds_base_ptrs (InlineArray[UInt32, 2])
- warp_id (UInt32)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
MMA_K
comptime MMA_K = mma_shape[2]
mma_layout
comptime mma_layout = row_major[((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2) * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_tiles), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width]()
MMA_N
comptime MMA_N = mma_shape[1]
MMATileType
comptime MMATileType = TileTensor[kv_t.dtype, Layout[ComptimeInt[((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2) * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_tiles)], ComptimeInt[KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width], ComptimeInt[(ComptimeInt[KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]], MutExternalOrigin, address_space=AddressSpace.LOCAL]
num_k_mmas2
comptime num_k_mmas2 = ceildiv(BK, (KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_K * k_group_size))
num_k_tiles
comptime num_k_tiles = ceildiv(depth if transpose else WN, BK)
num_mmas
comptime num_mmas = ceildiv(WN if transpose else depth, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_N)
num_repeats
comptime num_repeats = (depth // BK)
simd_width
comptime simd_width = simd_width_of[kv_t.dtype]()
smem_stage_size
comptime smem_stage_size = (BN * depth)
vm_instrs_per_load
comptime vm_instrs_per_load = SIMD((ceildiv(((BN // 32) * (depth // BK)), (num_threads // WARP_SIZE)) * 2))
warp_tile_rows
comptime warp_tile_rows = 32
wtile_dim0
comptime wtile_dim0 = WN
wtile_dim1
comptime wtile_dim1 = BK
Methods
__init__
__init__(out self, k_cache: kv_t, batch_idx: UInt, head_idx: UInt, shared_ptr: UnsafePointer[Scalar[kv_t.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], end: UInt, warp_id: UInt32)
load_from_dram
load_from_dram[buffer_idx: Int](mut self)
get_mma_tile
get_mma_tile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[ComptimeInt[KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas], ComptimeInt[KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width], ComptimeInt[(ComptimeInt[KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Returns:
copy_to_shared
copy_to_shared(self)
load_from_shared
load_from_shared(self, buffer: UInt)
load_from_shared[bk_tile: Int](self, buffer: UInt)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!