Mojo struct
KVBuffer
struct KVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool, full_kv: Bool = True, cache_depth: Int = depth, head_dim_offset: Int = 0, reg_chunk_depth: Int = depth, smem_depth: Int = depth]
KV cache buffer managing DMA, LDS staging, and register tiles.
Handles the full data path: DRAM -> LDS (shared memory) -> registers.
SMEM is navigated via .tile() on a strided parent TileTensor whose
(BK, BN) strides encode the blocked layout (num_repeats contiguous
BN×BK blocks per stage, two stages). Stage selection and in-stage
block selection both happen via the tile column index — no pointer
arithmetic required. smem_mma_subtile is still used for V-operand
MMA sub-tiles which have mma_cols != BK.
When full_kv=True (depth<=256), each SMEM stage holds BN × smem_depth elements — the full tile. When full_kv=False (depth=512), each stage holds only BN × BK elements, and the caller iterates over BK blocks.
smem_depth defaults to depth. It exists for per-warp V buffers
whose depth is smaller than BK (e.g. depth_per_warp=16 with
BK=32): the SMEM layout stays valid with smem_depth = max(depth, BK) while depth keeps driving register-tile sizing and the column
count read by load_from_shared.
MMA register tiles (mma_tile) are TileTensor in LOCAL address space. TiledMmaOp (mma.mojo) handles SMEM→register loads and MMA dispatch.
Fields:
- kv_mma_op (KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].KVMmaOpType)
- smem_tile (KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].SmemParentType)
- kv_cache_iter (KVCacheIterator[kv_t, BN, kv_num_heads, depth, cache_depth, head_dim_offset])
- warp_id (UInt32)
Implemented traits:
AnyType,
ImplicitlyDestructible
comptime members:
input_frag_size:
comptime input_frag_size = ((KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].MMA_K * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].MMA_N) // WARP_SIZE)
KVMmaOpType:
comptime KVMmaOpType = KVMmaOp[kv_t.dtype, mma_shape, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_mmas, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_k_mmas2, ceildiv(reg_chunk_depth, BK) if transpose else KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_k_tiles, BN, BK, transpose, swizzle]
MMA_K:
comptime MMA_K = mma_shape[2]
mma_layout:
comptime mma_layout = row_major[((KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_mmas * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_k_mmas2) * ceildiv(reg_chunk_depth, BK) if transpose else KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_k_tiles), KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].input_frag_size]()
MMA_N:
comptime MMA_N = mma_shape[1]
MMATileType:
comptime MMATileType = TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
num_k_mmas2:
comptime num_k_mmas2 = ceildiv(BK, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].MMA_K)
num_k_tiles:
comptime num_k_tiles = ceildiv(depth if transpose else WN, BK)
num_mmas:
comptime num_mmas = ceildiv(WN if transpose else depth, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].MMA_N)
num_repeats:
comptime num_repeats = (smem_depth // BK)
simd_width:
comptime simd_width = simd_width_of[kv_t.dtype]()
smem_cols:
comptime smem_cols = smem_depth if full_kv else BK
smem_stage_size:
comptime smem_stage_size = (BN * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].smem_cols)
SmemParentType:
comptime SmemParentType = TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
vm_instrs_per_load:
comptime vm_instrs_per_load = SIMD((ceildiv(((BN // 32) * (smem_depth // BK) if full_kv else 1), (num_threads // WARP_SIZE)) * 2))
warp_tile_rows:
comptime warp_tile_rows = 32
wtile_dim0:
comptime wtile_dim0 = WN
wtile_dim1:
comptime wtile_dim1 = BK
Methods:
__init__:
__init__(out self, k_cache: kv_t, batch_idx: Int, head_idx: Int, smem_tile: TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], end: Int, warp_id: UInt32)
smem_block_tile:
smem_block_tile[tile_rows: Int](self, tile_row: Int, block_col: Int) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
Get a (tile_rows, BK) row-major sub-tile from SMEM.
tile_row indexes along BN (rows within a BN×BK block), block_col indexes linearly across all BN×BK blocks in both stages.
Returns:
TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
load_from_dram:
load_from_dram[buffer_idx: Int](mut self)
get_mma_tile:
get_mma_tile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Returns:
TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
mma_subtile:
mma_subtile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Alias for get_mma_tile, kept for decode-call-site symmetry.
Returns:
TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
load_from_shared:
load_from_shared(self, buffer: Int)
load_from_shared[bk_tile: Int](self, buffer: Int)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!