Mojo struct

DecodeStreamingKVBuffer

struct DecodeStreamingKVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool]

Streaming-decode KV buffer: single-buffer SMEM staging with per-strip DMA.

Unlike KVBuffer, this takes an external DRAM tile per outer-loop iteration and loads BK-wide strips one at a time.

K (transpose=True): BN x BK SMEM, column strips from BN x depth. V (transpose=False): BK x depth SMEM (blocked BK x BK), row strips.
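The per-strip staging pattern can be sketched in Python. This is illustrative only: `FakeBuffer` and `process_outer_tile` are hypothetical stand-ins for the Mojo struct and its caller, and the GPU barriers between strips are elided.

```python
class FakeBuffer:
    """Stand-in for DecodeStreamingKVBuffer; records the call sequence."""
    def __init__(self):
        self.log = []
    def load_from_dram(self, strip_idx, gmem_tile):
        self.log.append(("dram->smem", strip_idx))
    def load_from_shared(self):
        self.log.append(("smem->regs",))

def process_outer_tile(buf, gmem_tile, num_strips):
    # One outer-loop iteration: stage each BK-wide strip into the single
    # SMEM buffer, then move it into MMA register fragments.
    for strip in range(num_strips):
        buf.load_from_dram(strip, gmem_tile)
        # (the real kernel places a barrier here so all threads see the strip)
        buf.load_from_shared()
        # ... issue MMAs via get_mma_tile[k_mma_idx]() ...
```

Because there is only one SMEM stage, each strip must be fully consumed before the next DMA overwrites it.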

Fields

  • kv_mma_op (DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].KVMmaOpType):
  • smem_ptr (UnsafePointer[Scalar[kv_t.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]):
  • warp_id (UInt32):

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

input_frag_size

comptime input_frag_size = ((DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_K * DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_N) // WARP_SIZE)

KVMmaOpType

comptime KVMmaOpType = KVMmaOp[kv_t.dtype, mma_shape, DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas, DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2, 1, BN, BK, transpose, swizzle]

MMA_K

comptime MMA_K = mma_shape[2]

mma_layout

comptime mma_layout = row_major[(DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].input_frag_size]()

MMA_N

comptime MMA_N = mma_shape[1]

MMATileType

comptime MMATileType = TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

num_k_mmas2

comptime num_k_mmas2 = ceildiv(BK, DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_K)

num_mmas

comptime num_mmas = ceildiv(WN if transpose else (depth // (BN // WN)), DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_N)
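The comptime shape parameters above can be worked through numerically. The parameter values below (`mma_shape`, `BN`, `WN`, `BK`, `depth`, `WARP_SIZE`) are hypothetical examples, not values mandated by the struct.

```python
def ceildiv(a: int, b: int) -> int:
    # Ceiling division, matching Mojo's ceildiv.
    return -(-a // b)

# Hypothetical compile-time parameters for illustration.
mma_shape = (16, 8, 16)
BN, WN, BK, depth = 64, 32, 32, 128
WARP_SIZE = 32  # assumed warp width

MMA_N, MMA_K = mma_shape[1], mma_shape[2]
num_k_mmas2 = ceildiv(BK, MMA_K)                  # MMAs along K within one strip -> 2
num_mmas_k = ceildiv(WN, MMA_N)                   # transpose=True (K path)       -> 4
num_mmas_v = ceildiv(depth // (BN // WN), MMA_N)  # transpose=False (V path)      -> 8
input_frag_size = (MMA_K * MMA_N) // WARP_SIZE    # per-lane fragment elements    -> 4
```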

simd_width

comptime simd_width = simd_width_of[kv_t.dtype]()

warp_tile_rows

comptime warp_tile_rows = 32

Methods

__init__

__init__(out self, cache: kv_t, batch_idx: Int, head_idx: Int, smem_ptr: UnsafePointer[Scalar[kv_t.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], num_keys: Int, warp_id: UInt32)

k_smem_block_tile

k_smem_block_tile[tile_rows: Int](self, tile_row: Int) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

Get a (tile_rows, BK) row-major sub-tile from the K SMEM.

Single-stage K SMEM has one BN×BK block, so only the row index along BN varies.

Returns:

TileTensor
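The row sub-tiling can be sketched as follows; the block shape and `tile_rows` value are hypothetical examples.

```python
BN, BK = 64, 32  # hypothetical single-stage K SMEM block shape
tile_rows = 16   # requested sub-tile height

# The single BN x BK block is carved into (tile_rows, BK) row sub-tiles,
# so the only degree of freedom is the row index along BN.
num_row_tiles = BN // tile_rows

def sub_tile_rows(tile_row: int) -> range:
    # Rows of the BN dimension covered by sub-tile `tile_row`.
    return range(tile_row * tile_rows, (tile_row + 1) * tile_rows)
```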

load_from_dram

load_from_dram[strip_idx: Int](self, gmem_tile: TileTensor[kv_t.dtype, gmem_tile.LayoutType, gmem_tile.origin, address_space=gmem_tile.address_space, linear_idx_type=gmem_tile.linear_idx_type, element_size=gmem_tile.element_size])

Load one BK-wide strip from an external DRAM tile to SMEM.

K (transpose=True): columns [strip*BK, (strip+1)*BK] from BN x depth. V (transpose=False): rows [strip*BK, (strip+1)*BK] from BN x depth.
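The strip bounds are the same index window for both paths; only the axis differs. A minimal sketch, assuming a hypothetical strip width:

```python
BK = 32  # hypothetical strip width

def strip_bounds(strip_idx: int) -> tuple:
    # Both paths stage the window [strip*BK, (strip+1)*BK]:
    # K reads it as columns of the BN x depth tile, V as rows.
    return strip_idx * BK, (strip_idx + 1) * BK
```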

load_from_shared

load_from_shared(self)

Load from SMEM to MMA registers.

get_mma_tile

get_mma_tile[k_mma_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Get register tile for one k_mma group within the single strip.

Returns:

TileTensor
