Mojo struct
KVBuffer
struct KVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[Int(3)], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool, full_kv: Bool = True, cache_depth: Int = depth, head_dim_offset: Int = Int(0), reg_chunk_depth: Int = depth, smem_depth: Int = depth, bk_smem: Int = BK]
KV cache buffer managing DMA, LDS staging, and register tiles.
Handles the full data path: DRAM -> LDS (shared memory) -> registers.
SMEM is navigated via .tile() on a strided parent TileTensor whose (BK, BN) strides encode the blocked layout (num_repeats contiguous BN×BK blocks per stage, two stages). Stage selection and in-stage block selection both happen via the tile column index, so no pointer arithmetic is required. smem_mma_subtile is still used for V-operand MMA sub-tiles, which have mma_cols != BK.
When full_kv=True (depth<=256), each SMEM stage holds BN x smem_depth elements: the full tile. When full_kv=False (depth=512), each stage holds only BN x BK elements, and the caller iterates over BK blocks.
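The stage geometry can be checked with plain arithmetic. The sketch below is a Python model, not Mojo, and assumes each block is a row-major BN x bk_smem tile with blocks (and the two stages) laid out back to back, as the description states; bk_smem defaults to BK. The helpers smem_geometry and locate_block are hypothetical names used only for illustration.

```python
# Python model of the per-stage SMEM geometry (not the Mojo implementation).
# Assumption: blocks are BN x bk_smem row-major tiles stored contiguously,
# stage 0 followed by stage 1.

def smem_geometry(BN, smem_depth, bk_smem, full_kv=True):
    """Return (smem_cols, stage_size, blocks_per_stage)."""
    smem_cols = smem_depth if full_kv else bk_smem  # mirrors comptime smem_cols
    stage_size = BN * smem_cols                     # mirrors comptime smem_stage_size
    blocks_per_stage = smem_cols // bk_smem         # num_repeats if full_kv, else 1
    return smem_cols, stage_size, blocks_per_stage


def locate_block(block_col, blocks_per_stage, BN, bk_smem):
    """Hypothetical helper: map a linear block column index to
    (stage, block within stage, element offset of the block)."""
    stage, block_in_stage = divmod(block_col, blocks_per_stage)
    # Contiguous blocks: one linear index covers both stages, so no
    # stage-specific pointer arithmetic is needed.
    offset = block_col * BN * bk_smem
    return stage, block_in_stage, offset


print(smem_geometry(64, 128, 32))                 # (128, 8192, 4)
print(locate_block(5, 4, 64, 32))                 # (1, 1, 10240)
print(smem_geometry(64, 512, 32, full_kv=False))  # (32, 2048, 1)
```

Block column 5 lands at offset 1 * 8192 + 1 * 2048 = 10240, the same place a stage-then-block calculation would reach.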
smem_depth defaults to depth. It exists for per-warp V buffers
whose depth is smaller than BK (e.g. depth_per_warp=16 with
BK=32): the SMEM layout stays valid with smem_depth = max(depth, BK) while depth keeps driving register-tile sizing and the column
count read by load_from_shared.
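A quick Python check shows why smem_depth must not drop below BK. It assumes bk_smem keeps its default value of BK and reuses the num_repeats formula listed under the comptime members.

```python
# Python model of comptime num_repeats = smem_depth // bk_smem.
def num_repeats(smem_depth, bk_smem):
    return smem_depth // bk_smem

depth_per_warp, BK = 16, 32  # the example from the text; bk_smem == BK assumed

print(num_repeats(depth_per_warp, BK))           # 0: no BK-wide block fits
print(num_repeats(max(depth_per_warp, BK), BK))  # 1: one BN x BK block per stage
```

With smem_depth = depth = 16 the stage would contain zero blocks. Raising smem_depth to max(depth, BK) keeps one valid block while depth still sizes the register tiles.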
MMA register tiles (mma_tile) are TileTensor in LOCAL address space. TiledMmaOp (mma.mojo) handles SMEM→register loads and MMA dispatch.
Fields
- kv_mma_op (KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].KVMmaOpType)
- smem_tile (KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].SmemParentType)
- kv_cache_iter (KVCacheIterator[kv_t, BN, kv_num_heads, depth, cache_depth, head_dim_offset])
- warp_id (UInt32)
Implemented traits
comptime members
input_frag_size
comptime input_frag_size = ((mma_shape[Int(2)] * mma_shape[Int(1)]) // _resolve_warp_size())
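As a worked example, input_frag_size is the K x N operand tile of one MMA divided evenly across the lanes of a warp. The Python sketch below assumes _resolve_warp_size() returns 64 (the AMD wavefront size) and uses illustrative (M, N, K) shapes. These shapes are not taken from this page.

```python
# Python model of comptime input_frag_size.
# Assumptions: mma_shape is (M, N, K); warp size 64 (AMD wavefront).
def input_frag_size(mma_shape, warp_size=64):
    _, mma_n, mma_k = mma_shape
    # K x N operand elements spread evenly over the lanes of one warp.
    return (mma_k * mma_n) // warp_size

print(input_frag_size((32, 32, 8)))   # 4 elements per lane
print(input_frag_size((16, 16, 16)))  # 4 elements per lane
```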
KVMmaOpType
comptime KVMmaOpType = KVMmaOp[kv_t.dtype, mma_shape, ceildiv(WN if transpose else depth, mma_shape[Int(1)]), ceildiv(BK, mma_shape[Int(2)]), ceildiv(reg_chunk_depth, BK) if transpose else KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].num_k_tiles, BN, BK, transpose, swizzle]
MMA_K
comptime MMA_K = mma_shape[Int(2)]
mma_layout
comptime mma_layout = row_major[((ceildiv(WN if transpose else depth, mma_shape[Int(1)]) * ceildiv(BK, mma_shape[Int(2)])) * ceildiv(reg_chunk_depth, BK) if transpose else KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].num_k_tiles), ((mma_shape[Int(2)] * mma_shape[Int(1)]) // _resolve_warp_size())]()
MMA_N
comptime MMA_N = mma_shape[Int(1)]
MMATileType
comptime MMATileType = TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
num_k_mmas2
comptime num_k_mmas2 = ceildiv(BK, mma_shape[Int(2)])
num_k_tiles
comptime num_k_tiles = ceildiv(depth if transpose else WN, BK)
num_mmas
comptime num_mmas = ceildiv(WN if transpose else depth, mma_shape[Int(1)])
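The three tile counts above swap roles depending on transpose. The Python model below evaluates them for one illustrative configuration: mma_shape = (32, 32, 8), WN = 32, BK = 32 and depth = 128 are assumed values, not defaults from this page. When transpose is set, KVMmaOpType passes ceildiv(reg_chunk_depth, BK) in place of num_k_tiles. With the default reg_chunk_depth = depth, this gives the same number.

```python
# Python model of comptime num_mmas, num_k_mmas2 and num_k_tiles.
def ceildiv(a, b):
    return -(-a // b)

def tile_counts(mma_shape, WN, BK, depth, transpose):
    _, mma_n, mma_k = mma_shape
    num_k_mmas2 = ceildiv(BK, mma_k)                       # MMAs along K inside one BK slice
    num_k_tiles = ceildiv(depth if transpose else WN, BK)  # BK slices along the K extent
    num_mmas = ceildiv(WN if transpose else depth, mma_n)  # MMAs along the N extent
    return num_mmas, num_k_mmas2, num_k_tiles

print(tile_counts((32, 32, 8), WN=32, BK=32, depth=128, transpose=True))   # (1, 4, 4)
print(tile_counts((32, 32, 8), WN=32, BK=32, depth=128, transpose=False))  # (4, 4, 1)
```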
num_repeats
comptime num_repeats = (smem_depth // bk_smem)
simd_width
comptime simd_width = simd_width_of[kv_t.dtype]()
smem_cols
comptime smem_cols = smem_depth if full_kv else bk_smem
smem_stage_size
comptime smem_stage_size = (BN * KVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose, full_kv, cache_depth, head_dim_offset, reg_chunk_depth, smem_depth, bk_smem].smem_cols)
SmemParentType
comptime SmemParentType = TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
vm_instrs_per_load
comptime vm_instrs_per_load = SIMD(ceildiv((BN // Int(32)) * ((smem_depth // bk_smem) if full_kv else Int(1)), (num_threads // _resolve_warp_size())) * 2)
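Evaluating vm_instrs_per_load by hand shows how work is split across warps. The count of 32-row blocks in a stage is divided among the warps, rounded up, and doubled. The Python model below drops the SIMD wrapper. It assumes a warp size of 64 and uses illustrative sizes (BN = 128, bk_smem = 32, num_threads = 256); the page does not state any of these values.

```python
# Python model of comptime vm_instrs_per_load (SIMD wrapper omitted).
def ceildiv(a, b):
    return -(-a // b)

def vm_instrs_per_load(BN, smem_depth, bk_smem, num_threads, full_kv=True, warp_size=64):
    blocks = (BN // 32) * ((smem_depth // bk_smem) if full_kv else 1)
    num_warps = num_threads // warp_size
    return ceildiv(blocks, num_warps) * 2

print(vm_instrs_per_load(128, 128, 32, 256))                 # 8
print(vm_instrs_per_load(128, 512, 32, 256, full_kv=False))  # 2
```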
warp_tile_rows
comptime warp_tile_rows = 32
wtile_dim0
comptime wtile_dim0 = WN
wtile_dim1
comptime wtile_dim1 = BK
Methods
__init__
def __init__(out self, k_cache: kv_t, batch_idx: Int, head_idx: Int, smem_tile: TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], end: Int, warp_id: UInt32)
smem_block_tile
def smem_block_tile[tile_rows: Int](self, tile_row: Int, block_col: Int) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
Get a (tile_rows, bk_smem) row-major sub-tile from SMEM.
tile_row indexes along BN (rows within a BN×bk_smem block); block_col indexes linearly across all BN×bk_smem blocks in both stages.
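The index pair can be turned into an element offset with a small model. The Python sketch below assumes three things: tile_row counts in units of tile_rows (as .tile() coordinates usually do), blocks are contiguous row-major BN x bk_smem tiles, and offsets are in elements. smem_block_tile_origin is a hypothetical name, not part of the API.

```python
# Hypothetical model of where smem_block_tile's sub-tile starts in SMEM.
def smem_block_tile_origin(tile_row, block_col, tile_rows, BN, bk_smem):
    # The sub-tile must start inside one BN-row block.
    assert 0 <= tile_row * tile_rows < BN
    block_base = block_col * BN * bk_smem  # contiguous BN x bk_smem blocks
    # Row-major inside a block: the row stride is bk_smem elements.
    return block_base + tile_row * tile_rows * bk_smem

print(smem_block_tile_origin(2, 5, tile_rows=16, BN=64, bk_smem=32))  # 11264
```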
Returns:
TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
load_from_dram
def load_from_dram[buffer_idx: Int](mut self)
get_mma_tile
def get_mma_tile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
Returns:
TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
mma_subtile
def mma_subtile[k_mma_tile_idx: Int, bk_tile_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
Alias for get_mma_tile, kept for decode-call-site symmetry.
Returns:
TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
zero_partial_tile_pad
def zero_partial_tile_pad(self)
Register-side zero for the OOB tail of the partial K-tile.
When depth % BK != 0, the last K-tile (i = depth // BK) spans BK K-positions, but only valid_cols = depth - i*BK of them are valid; the trailing BK - valid_cols positions are padding and must read as zero.
Per-lane K-fragment layout: the input_frag_size elements per lane interleave across MFMA-K such that the LOWER valid_per_lane = input_frag_size * valid_cols / BK elements correspond to the valid K-range, and the UPPER input_frag_size - valid_per_lane elements correspond to the pad. This method zeros the upper portion in each lane.
The reference pre-zeros half of the partial-tile K-fragment dwords once and reuses them. Here they are re-zeroed after every K LDS load, because the LDS loader fills the whole register tile.
A no-op when depth % BK == 0.
NOTE: keep in sync with QRegisterBuffer.__init__'s partial-tile zeroing in buffers.mojo. Both sites share the upper-half-is-pad assumption (asserted in the implementation); a future config that violates it (valid_cols > BK/2) needs a different zero pattern in both.
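The per-lane zeroing rule can be modeled on plain lists. The Python sketch below is an illustration, not the register-level Mojo code. Each inner list stands for one lane's K-fragment of the partial tile. depth = 80 and BK = 32 are example values chosen only because they leave a partial tile.

```python
# Python model of zero_partial_tile_pad on per-lane K-fragments.
def zero_partial_tile_pad(frags, depth, BK):
    if depth % BK == 0:
        return frags                    # no partial K-tile: no-op
    i = depth // BK                     # index of the partial K-tile
    valid_cols = depth - i * BK
    # Upper-half-is-pad assumption from the NOTE above.
    assert 2 * valid_cols <= BK
    for lane in frags:
        input_frag_size = len(lane)
        assert (input_frag_size * valid_cols) % BK == 0
        valid_per_lane = input_frag_size * valid_cols // BK
        for j in range(valid_per_lane, input_frag_size):
            lane[j] = 0                 # zero the UPPER (pad) slots of this lane
    return frags

print(zero_partial_tile_pad([[1, 2, 3, 4], [5, 6, 7, 8]], depth=80, BK=32))
# [[1, 2, 0, 0], [5, 6, 0, 0]]
```

With depth = 80 the partial tile has valid_cols = 16. Each 4-element fragment therefore keeps its lower 2 slots and clears the upper 2.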
load_from_shared
def load_from_shared(self, buffer: Int)
def load_from_shared[bk_tile: Int](self, buffer: Int)