Mojo struct
DecodeStreamingKVBuffer
struct DecodeStreamingKVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool]
Streaming-decode KV buffer: single-buffer SMEM staging with per-strip DMA.
Unlike KVBuffer, this takes an external DRAM tile per outer-loop iteration and loads BK-wide strips one at a time.
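K (transpose=True): SMEM holds a BN x BK tile, filled with BK-wide column strips of the BN x depth DRAM tile. V (transpose=False): SMEM holds a BK x depth tile (stored as BK x BK blocks), filled with BK-high row strips of the BN x depth DRAM tile.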
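The V staging layout stores a BK x depth strip as depth/BK BK x BK blocks. The sketch below shows one possible element-offset mapping for a blocked layout like this. It assumes that each block is row-major, that the blocks sit back to back along depth, and that there is no swizzle. The real SMEM layout may permute elements through the swizzle parameter. The function name blocked_offset is hypothetical.

```python
def blocked_offset(row: int, col: int, bk: int, depth: int) -> int:
    """Element offset of (row, col) in a BK x depth tile stored as depth // bk
    back-to-back row-major BK x BK blocks (assumed layout, swizzle ignored)."""
    assert depth % bk == 0, "depth must be a multiple of BK"
    assert 0 <= row < bk and 0 <= col < depth
    block = col // bk       # which BK x BK block along depth
    inner_col = col % bk    # column inside that block
    return block * bk * bk + row * bk + inner_col


# Illustrative values: BK = 32, depth = 128 -> four 32 x 32 blocks.
BK, DEPTH = 32, 128
print(blocked_offset(0, 0, BK, DEPTH))    # 0
print(blocked_offset(1, 0, BK, DEPTH))    # 32
print(blocked_offset(0, 32, BK, DEPTH))   # 1024 (first element of block 1)
```

In a plain row-major BK x depth tile, element (0, 32) would sit at offset 32. The blocked form keeps each BK x BK sub-tile contiguous instead.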
Fields
- kv_mma_op (DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].KVMmaOpType)
- smem_ptr (UnsafePointer[Scalar[kv_t.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
- warp_id (UInt32)
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
input_frag_size
comptime input_frag_size = ((DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_K * DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_N) // WARP_SIZE)
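input_frag_size is the number of operand elements each lane holds for one MMA. It is computed by dividing the MMA_K x MMA_N operand tile evenly across the WARP_SIZE lanes of a warp. The numbers below are illustrative assumptions: a 32x32x8 (M, N, K) MMA shape on a 64-lane wavefront. Substitute the values for your target.

```python
WARP_SIZE = 64            # assumption: 64-lane wavefront
mma_shape = (32, 32, 8)   # assumption: (M, N, K) of the MMA instruction

MMA_N = mma_shape[1]
MMA_K = mma_shape[2]

# Each lane holds an equal share of the MMA_K x MMA_N operand tile.
input_frag_size = (MMA_K * MMA_N) // WARP_SIZE
print(input_frag_size)    # 4
```

With a 32-lane warp, the same shape would give 8 elements per lane.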
KVMmaOpType
comptime KVMmaOpType = KVMmaOp[kv_t.dtype, mma_shape, DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas, DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2, 1, BN, BK, transpose, swizzle]
MMA_K
comptime MMA_K = mma_shape[2]
mma_layout
comptime mma_layout = row_major[(DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].input_frag_size]()
MMA_N
comptime MMA_N = mma_shape[1]
MMATileType
comptime MMATileType = TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
num_k_mmas2
comptime num_k_mmas2 = ceildiv(BK, DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_K)
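num_k_mmas2 counts how many MMA steps along the reduction (K) dimension are needed to cover one BK-wide strip. Because the formula uses ceildiv, a strip that is not a multiple of MMA_K still rounds up to a whole extra step. The MMA_K value below is an illustrative assumption.

```python
def ceildiv(a: int, b: int) -> int:
    # Integer ceiling division for positive ints, same as ceil(a / b).
    return -(-a // b)


MMA_K = 8  # assumption: mma_shape[2]
for BK in (32, 16, 12):
    print(BK, ceildiv(BK, MMA_K))
# 32 4
# 16 2
# 12 2
```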
num_mmas
comptime num_mmas = ceildiv(WN if transpose else (depth // (BN // WN)), DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_N)
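The following sketch mirrors the comptime arithmetic for num_mmas, num_k_mmas2, input_frag_size and the shape of mma_layout. It is a plain transcription of the formulas listed above, not the library code. All parameter values are illustrative assumptions, and the helper name decode_buffer_shapes is hypothetical. The only difference between K and V is the per-warp N extent that num_mmas divides by MMA_N.

```python
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def decode_buffer_shapes(BN, WN, BK, depth, mma_shape, warp_size, transpose):
    """Return (num_mmas, num_k_mmas2, mma_layout_shape) per the formulas above."""
    mma_n, mma_k = mma_shape[1], mma_shape[2]
    # transpose=True (K): N extent is WN.
    # transpose=False (V): N extent is depth split across BN // WN warps.
    n_extent = WN if transpose else depth // (BN // WN)
    num_mmas = ceildiv(n_extent, mma_n)
    num_k_mmas2 = ceildiv(BK, mma_k)
    input_frag_size = (mma_k * mma_n) // warp_size
    # mma_layout = row_major[num_mmas * num_k_mmas2, input_frag_size]
    return num_mmas, num_k_mmas2, (num_mmas * num_k_mmas2, input_frag_size)


# Illustrative values: BN=64, WN=32, BK=32, depth=128, 32x32x8 MMA, 64 lanes.
print(decode_buffer_shapes(64, 32, 32, 128, (32, 32, 8), 64, True))   # (1, 4, (4, 4))
print(decode_buffer_shapes(64, 32, 32, 128, (32, 32, 8), 64, False))  # (2, 4, (8, 4))
```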
simd_width
comptime simd_width = simd_width_of[kv_t.dtype]()
warp_tile_rows
comptime warp_tile_rows = 32
Methods
__init__
__init__(out self, cache: kv_t, batch_idx: Int, head_idx: Int, smem_ptr: UnsafePointer[Scalar[kv_t.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], num_keys: Int, warp_id: UInt32)
k_smem_block_tile
k_smem_block_tile[tile_rows: Int](self, tile_row: Int) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
Get a (tile_rows, BK) row-major sub-tile from the K SMEM.
Single-stage K SMEM has one BN×BK block, so only the row index along BN varies.
Returns: A (tile_rows, BK) TileTensor view into the K SMEM buffer.
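The K SMEM buffer is a single row-major BN x BK block, so selecting a (tile_rows, BK) sub-tile only requires a row offset. The sketch below computes the element range such a sub-tile would cover. It assumes that tile_row counts whole tile_rows-high tiles and ignores any swizzle applied to the addressing. The helper k_block_tile_range is hypothetical.

```python
def k_block_tile_range(tile_row: int, tile_rows: int, BN: int, BK: int):
    """Element range [start, end) of the tile_row-th (tile_rows x BK) sub-tile
    of a row-major BN x BK buffer (assumed indexing, swizzle ignored)."""
    assert BN % tile_rows == 0
    assert 0 <= tile_row < BN // tile_rows
    start = tile_row * tile_rows * BK   # skip whole tiles of BK-wide rows
    return start, start + tile_rows * BK


# Illustrative values: BN = 64, BK = 32, tile_rows = 32 -> two sub-tiles.
print(k_block_tile_range(0, 32, 64, 32))   # (0, 1024)
print(k_block_tile_range(1, 32, 64, 32))   # (1024, 2048)
```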
load_from_dram
load_from_dram[strip_idx: Int](self, gmem_tile: TileTensor[kv_t.dtype, gmem_tile.LayoutType, gmem_tile.origin, address_space=gmem_tile.address_space, linear_idx_type=gmem_tile.linear_idx_type, element_size=gmem_tile.element_size])
Load one BK-wide strip from an external DRAM tile to SMEM.
K (transpose=True): columns [strip_idx*BK, (strip_idx+1)*BK) of the BN x depth tile. V (transpose=False): rows [strip_idx*BK, (strip_idx+1)*BK) of the BN x depth tile.
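The strip selection above can be illustrated on a small host-side tile, represented here as a list of rows. K strips are column slices that produce BN x BK pieces. V strips are row slices that produce BK x depth pieces. A K tile therefore has depth // BK strips, and a V tile has BN // BK strips. The function name take_strip and the tile sizes are illustrative assumptions.

```python
def take_strip(tile, strip_idx: int, BK: int, transpose: bool):
    """Return the strip_idx-th BK-wide strip of a BN x depth tile.
    transpose=True (K):  columns [strip_idx*BK, (strip_idx+1)*BK) -> BN x BK.
    transpose=False (V): rows    [strip_idx*BK, (strip_idx+1)*BK) -> BK x depth."""
    lo, hi = strip_idx * BK, (strip_idx + 1) * BK
    if transpose:
        return [row[lo:hi] for row in tile]
    return [row[:] for row in tile[lo:hi]]


# Illustrative 4 x 8 tile (BN = 4, depth = 8) with element value r * 8 + c.
BN, DEPTH, BK = 4, 8, 2
tile = [[r * DEPTH + c for c in range(DEPTH)] for r in range(BN)]
print(take_strip(tile, 1, BK, transpose=True))   # [[2, 3], [10, 11], [18, 19], [26, 27]]
print(take_strip(tile, 1, BK, transpose=False))  # rows 2 and 3
```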
load_from_shared
load_from_shared(self)
Load from SMEM to MMA registers.
get_mma_tile
get_mma_tile[k_mma_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Get register tile for one k_mma group within the single strip.
Returns: The LOCAL (register) TileTensor holding the operand fragments for k-MMA group k_mma_idx.
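Taken together, the methods suggest the following per-iteration order. First, stage one strip from DRAM into the single SMEM buffer. Next, move that strip into MMA registers. Finally, consume it one k-MMA group at a time. The sketch below records that order with a hypothetical mock class. The barrier comments are assumptions about what a single-buffer SMEM scheme generally needs. In Mojo, strip_idx and k_mma_idx are compile-time parameters, so the real loops would be unrolled at compile time.

```python
class MockStreamingKVBuffer:
    """Hypothetical stand-in that records the order of buffer calls."""

    def __init__(self, num_k_mmas2: int):
        self.num_k_mmas2 = num_k_mmas2
        self.log = []

    def load_from_dram(self, strip_idx: int, gmem_tile) -> None:
        self.log.append(("load_from_dram", strip_idx))

    def load_from_shared(self) -> None:
        self.log.append(("load_from_shared",))

    def get_mma_tile(self, k_mma_idx: int) -> int:
        self.log.append(("get_mma_tile", k_mma_idx))
        return k_mma_idx  # placeholder for the register tile


def process_tile(buf, gmem_tile, num_strips: int) -> list:
    """One outer-loop iteration: stream every BK-wide strip of gmem_tile."""
    tiles = []
    for strip_idx in range(num_strips):
        buf.load_from_dram(strip_idx, gmem_tile)  # DRAM -> SMEM, one strip
        # Assumed: a barrier here so all writes land before SMEM is read.
        buf.load_from_shared()                    # SMEM -> MMA registers
        for k in range(buf.num_k_mmas2):
            tiles.append(buf.get_mma_tile(k))     # one MMA group per k step
        # Assumed: a barrier here before the single SMEM buffer is reused.
    return tiles


buf = MockStreamingKVBuffer(num_k_mmas2=2)
process_tile(buf, gmem_tile=None, num_strips=2)
print(buf.log)
```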