Mojo struct

KVBuffer

struct KVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool, full_kv: Bool = True, cache_depth: Int = depth, head_dim_offset: Int = 0, reg_chunk_depth: Int = depth, smem_depth: Int = depth]

KV cache buffer managing DMA, LDS staging, and register tiles.

Handles the full data path: DRAM -> LDS (shared memory) -> registers.

SMEM is navigated via .tile() on a strided parent TileTensor whose (BK, BN) strides encode the blocked layout (num_repeats contiguous BN×BK blocks per stage, two stages). Stage selection and in-stage block selection both happen via the tile column index; no pointer arithmetic is required. smem_mma_subtile is still used for V-operand MMA sub-tiles, which have mma_cols != BK.
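As a sketch of that column-index navigation (Python arithmetic only; the `stage` and `repeat` names are illustrative, not part of the API), both selections collapse into one linear block-column index:

```python
# Illustrative index arithmetic for the blocked SMEM layout.
# The real code navigates via TileTensor.tile(); this only shows how a
# single column index can encode both stage and in-stage block choice.
def block_col(stage: int, repeat: int, smem_depth: int, bk: int) -> int:
    num_repeats = smem_depth // bk       # contiguous BNxBK blocks per stage
    assert 0 <= stage < 2                # double-buffered: two stages
    assert 0 <= repeat < num_repeats
    return stage * num_repeats + repeat  # one index, no pointer arithmetic

# e.g. smem_depth=128, BK=32 -> 4 blocks per stage;
# stage 1, block 2 lands at block column 6.
print(block_col(1, 2, 128, 32))
```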

When full_kv=True (depth <= 256), each SMEM stage holds BN x smem_depth elements: the full tile. When full_kv=False (depth = 512), each stage holds only BN x BK elements, and the caller iterates over BK blocks.
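The per-stage element count follows directly from the smem_cols definition below (smem_depth if full_kv else BK). A quick numeric check, with illustrative tile sizes:

```python
# Per-stage SMEM element count, mirroring smem_cols / smem_stage_size.
def smem_stage_elems(bn: int, bk: int, smem_depth: int, full_kv: bool) -> int:
    smem_cols = smem_depth if full_kv else bk
    return bn * smem_cols

# full_kv=True, depth=256: one stage holds the full BN x depth tile.
print(smem_stage_elems(32, 32, 256, True))    # 32 * 256 = 8192 elements
# full_kv=False, depth=512: one stage holds a single BN x BK block,
# and the caller iterates over 512 // 32 = 16 such blocks.
print(smem_stage_elems(32, 32, 512, False))   # 32 * 32 = 1024 elements
```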

smem_depth defaults to depth. It exists for per-warp V buffers whose depth is smaller than BK (e.g. depth_per_warp=16 with BK=32): the SMEM layout stays valid with smem_depth = max(depth, BK) while depth keeps driving register-tile sizing and the column count read by load_from_shared.
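For the per-warp V-buffer case this describes, the arithmetic works out as follows (a sketch, assuming smem_depth is chosen as max(depth, BK) as stated above):

```python
# Per-warp V buffer whose depth is smaller than BK.
def smem_depth_for(depth: int, bk: int) -> int:
    # Keep the SMEM layout valid: never fewer columns than one BK block.
    return max(depth, bk)

depth_per_warp, BK = 16, 32
sd = smem_depth_for(depth_per_warp, BK)
print(sd)          # 32: SMEM layout stays valid
print(sd // BK)    # num_repeats = smem_depth // BK = 1 block per stage
# depth (16) still drives register-tile sizing and the column count
# read by load_from_shared, independently of smem_depth.
```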

MMA register tiles (mma_tile) are TileTensor in LOCAL address space. TiledMmaOp (mma.mojo) handles SMEM -> register loads and MMA dispatch.

Fields​

  • ​kv_mma_op (KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].KVMmaOpType):
  • ​smem_tile (KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].SmemParentType):
  • ​kv_cache_iter (KVCacheIterator[kv_t, BN, kv_num_heads, depth, cache_depth, head_dim_offset]):
  • ​warp_id (UInt32):

Implemented traits​

AnyType, ImplicitlyDestructible

comptime members​

input_frag_size​

comptime input_frag_size = ((KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].MMA_K * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].MMA_N) // WARP_SIZE)
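Spelled out, this is the per-lane fragment size of one MMA tile: (MMA_K * MMA_N) // WARP_SIZE. A numeric sketch (the 16x16x16 shape and wavefront size of 64 are illustrative assumptions, the latter matching AMD GPUs, where LDS terminology applies):

```python
# Per-lane input fragment size for one MMA tile.
def input_frag_size(mma_shape: tuple, warp_size: int) -> int:
    mma_n, mma_k = mma_shape[1], mma_shape[2]  # MMA_N = shape[1], MMA_K = shape[2]
    return (mma_k * mma_n) // warp_size

# Assumed example: mma_shape=(16, 16, 16), WARP_SIZE=64 (AMD wavefront).
print(input_frag_size((16, 16, 16), 64))  # (16 * 16) // 64 = 4 elements per lane
```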

KVMmaOpType​

comptime KVMmaOpType = KVMmaOp[kv_t.dtype, mma_shape, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_mmas, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_k_mmas2, ceildiv(reg_chunk_depth, BK) if transpose else KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_k_tiles, BN, BK, transpose, swizzle]

MMA_K​

comptime MMA_K = mma_shape[2]

mma_layout​

comptime mma_layout = row_major[((KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_mmas * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_k_mmas2) * ceildiv(reg_chunk_depth, BK) if transpose else KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].num_k_tiles), KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].input_frag_size]()

MMA_N​

comptime MMA_N = mma_shape[1]

MMATileType​

comptime MMATileType = TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

num_k_mmas2​

comptime num_k_mmas2 = ceildiv(BK, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].MMA_K)

num_k_tiles​

comptime num_k_tiles = ceildiv(depth if transpose else WN, BK)

num_mmas​

comptime num_mmas = ceildiv(WN if transpose else depth, KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].MMA_N)
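The three tiling counts above swap roles with transpose: num_mmas tiles WN or depth by MMA_N, num_k_tiles tiles the other of the pair by BK, and num_k_mmas2 always tiles BK by MMA_K. A sketch with an assumed sample configuration (BK=32, WN=32, depth=128, mma_shape=(16, 16, 16)):

```python
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)

BK, WN, depth = 32, 32, 128
MMA_N, MMA_K = 16, 16  # mma_shape[1], mma_shape[2]

for transpose in (True, False):
    num_mmas = ceildiv(WN if transpose else depth, MMA_N)
    num_k_mmas2 = ceildiv(BK, MMA_K)
    num_k_tiles = ceildiv(depth if transpose else WN, BK)
    print(transpose, num_mmas, num_k_mmas2, num_k_tiles)
# transpose=True:  num_mmas=2, num_k_mmas2=2, num_k_tiles=4
# transpose=False: num_mmas=8, num_k_mmas2=2, num_k_tiles=1
```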

num_repeats​

comptime num_repeats = (smem_depth // BK)

simd_width​

comptime simd_width = simd_width_of[kv_t.dtype]()

smem_cols​

comptime smem_cols = smem_depth if full_kv else BK

smem_stage_size​

comptime smem_stage_size = (BN * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth].smem_cols)

SmemParentType​

comptime SmemParentType = TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

vm_instrs_per_load​

comptime vm_instrs_per_load = SIMD((ceildiv(((BN // 32) * (smem_depth // BK) if full_kv else 1), (num_threads // WARP_SIZE)) * 2))

warp_tile_rows​

comptime warp_tile_rows = 32

wtile_dim0​

comptime wtile_dim0 = WN

wtile_dim1​

comptime wtile_dim1 = BK

Methods​

__init__​

__init__(out self, k_cache: kv_t, batch_idx: Int, head_idx: Int, smem_tile: TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], end: Int, warp_id: UInt32)

smem_block_tile​

smem_block_tile[tile_rows: Int](self, tile_row: Int, block_col: Int) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

Get a (tile_rows, BK) row-major sub-tile from SMEM.

tile_row indexes along BN (rows within a BN×BK block); block_col indexes linearly across all BN×BK blocks in both stages.

Returns:

TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

load_from_dram​

load_from_dram[buffer_idx: Int](mut self)

get_mma_tile​

get_mma_tile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Returns:

TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

mma_subtile​

mma_subtile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Alias for get_mma_tile, kept for decode-call-site symmetry.

Returns:

TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

load_from_shared​

load_from_shared(self, buffer: Int)

load_from_shared[bk_tile: Int](self, buffer: Int)