
Mojo struct

DecodeStreamingKVBuffer

struct DecodeStreamingKVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[Int(3)], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool]

Streaming-decode KV buffer: single-buffer SMEM staging with per-strip DMA.

Unlike KVBuffer, this takes an external DRAM tile per outer-loop iteration and loads BK-wide strips one at a time.

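K (transpose=True): the SMEM buffer is BN x BK, and each strip is a BK-wide block of columns taken from the BN x depth DRAM tile.

V (transpose=False): the SMEM buffer is BK x depth (stored as BK x BK blocks), and each strip is a BK-tall block of rows taken from the BN x depth DRAM tile.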
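The two SMEM arrangements can be modeled with plain index arithmetic. The Python sketch below is illustrative only: it assumes row-major order inside each BK x BK block of the V buffer and ignores the swizzle parameter, neither of which this page specifies, and the helper names k_smem_offset and v_smem_offset are hypothetical.

```python
# Plain-Python model of SMEM element offsets (not Mojo, not the MAX code).
# Assumptions: row-major order inside each BK x BK block, no swizzle applied.

def k_smem_offset(r: int, c: int, BN: int, BK: int) -> int:
    """K (transpose=True): BN x BK buffer in plain row-major order."""
    assert 0 <= r < BN and 0 <= c < BK
    return r * BK + c


def v_smem_offset(r: int, c: int, BK: int, depth: int) -> int:
    """V (transpose=False): BK x depth buffer stored as depth // BK blocks of BK x BK."""
    assert 0 <= r < BK and 0 <= c < depth and depth % BK == 0
    block = c // BK             # which BK x BK block along depth
    within = r * BK + c % BK    # row-major position inside that block
    return block * BK * BK + within
```

For BK = 2 and depth = 4, V element (0, 2) lands at offset 4 rather than the row-major offset 2, because it is the first element of the second BK x BK block.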

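Fields

  • kv_mma_op (DecodeStreamingKVBuffer[mma_shape, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].KVMmaOpType): MMA operand helper (a KVMmaOp) that loads SMEM data into the MMA register tiles.
  • smem_ptr (UnsafePointer[Scalar[kv_t.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer to the single-buffer SMEM staging area.
  • warp_id (UInt32): Warp index passed to __init__.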

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

input_frag_size

comptime input_frag_size = (Int(mma_shape[Int(2)] * mma_shape[Int(1)]) // _resolve_warp_size())

KVMmaOpType

comptime KVMmaOpType = KVMmaOp[kv_t.dtype, mma_shape, ceildiv(WN if transpose else (depth // (BN // WN)), mma_shape[Int(1)]), ceildiv(BK, mma_shape[Int(2)]), Int(1), BN, BK, transpose, swizzle]
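The third and fourth parameters of KVMmaOpType are the same expressions as num_mmas and num_k_mmas2 below. A plain-Python sketch of that arithmetic follows; the example sizes are hypothetical, and reading BN // WN as the number of warps along BN is an interpretation of the formula, not something this page states.

```python
# Plain-Python model of the MMA counts used by KVMmaOpType (not Mojo).

def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division, matching ceildiv in the formulas above."""
    return -(-a // b)


def num_mmas(BN: int, WN: int, depth: int, mma_n: int, transpose: bool) -> int:
    # K (transpose=True): each warp covers WN positions along the MMA N dimension.
    # V (transpose=False): depth is divided among BN // WN warps (interpretation).
    n_extent = WN if transpose else depth // (BN // WN)
    return ceildiv(n_extent, mma_n)


def num_k_mmas2(BK: int, mma_k: int) -> int:
    # MMAs needed to cover one BK-wide strip along the reduction dimension.
    return ceildiv(BK, mma_k)
```

With BN = 64, WN = 32, depth = 128, and a 16 x 16 x 16 MMA, K gives num_mmas = 2 and V gives num_mmas = 4 (128 // 2 = 64, then 64 / 16); BK = 32 gives num_k_mmas2 = 2.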

MMA_K

comptime MMA_K = mma_shape[Int(2)]

mma_layout

comptime mma_layout = row_major[(ceildiv(WN if transpose else (depth // (BN // WN)), mma_shape[Int(1)]) * ceildiv(BK, mma_shape[Int(2)])), (Int(mma_shape[Int(2)] * mma_shape[Int(1)]) // _resolve_warp_size())]()
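mma_layout sizes the per-lane register storage: num_mmas * num_k_mmas2 rows, each holding input_frag_size elements, where input_frag_size is one MMA_K x MMA_N operand tile split evenly across the lanes of a warp. The Python sketch below models only this shape arithmetic; it reads mma_shape as (M, N, K), consistent with MMA_N = mma_shape[1] and MMA_K = mma_shape[2], and the MMA shape and warp sizes in the example are hypothetical.

```python
# Plain-Python model of the mma_layout shape (not Mojo, not the MAX code).

def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def mma_layout_shape(mma_shape, n_extent: int, BK: int, warp_size: int):
    """(rows, cols) of the row-major register layout; n_extent is WN for K,
    depth // (BN // WN) for V."""
    _, mma_n, mma_k = mma_shape                  # mma_shape read as (M, N, K)
    assert (mma_k * mma_n) % warp_size == 0      # equal share per lane
    rows = ceildiv(n_extent, mma_n) * ceildiv(BK, mma_k)  # num_mmas * num_k_mmas2
    cols = (mma_k * mma_n) // warp_size                  # input_frag_size
    return rows, cols
```

For a 16 x 16 x 16 MMA with n_extent = 32 and BK = 32, a 32-lane warp gives a 4 x 8 layout (32 elements per lane), while a 64-lane warp gives 4 x 4.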

MMA_N

comptime MMA_N = mma_shape[Int(1)]

MMATileType

comptime MMATileType = TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

num_k_mmas2

comptime num_k_mmas2 = ceildiv(BK, mma_shape[Int(2)])

num_mmas

comptime num_mmas = ceildiv(WN if transpose else (depth // (BN // WN)), mma_shape[Int(1)])

simd_width

comptime simd_width = simd_width_of[kv_t.dtype]()

warp_tile_rows

comptime warp_tile_rows = 32

Methods

__init__

def __init__(out self, cache: kv_t, batch_idx: Int, head_idx: Int, smem_ptr: UnsafePointer[Scalar[kv_t.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], num_keys: Int, warp_id: UInt32)

k_smem_block_tile

def k_smem_block_tile[tile_rows: Int](self, tile_row: Int) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

Get a (tile_rows, BK) row-major sub-tile from the K SMEM.

The single-stage K SMEM holds one BN x BK block, so only the row index along BN varies.

Returns:

TileTensor[kv_t.dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
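Assuming tile_row counts in units of tile_rows and the K buffer is plain row-major (swizzling ignored), the sub-tile's position reduces to the arithmetic below. Both assumptions, and the helper name k_block_tile_span, are illustrative rather than taken from the implementation.

```python
# Plain-Python model of where one (tile_rows, BK) sub-tile sits in the
# BN x BK K buffer. Assumes row-major storage and no swizzle.

def k_block_tile_span(tile_row: int, tile_rows: int, BN: int, BK: int):
    """Return (first_row, start_offset, num_elements) of one sub-tile."""
    assert BN % tile_rows == 0, "sketch assumes tile_rows divides BN"
    assert 0 <= tile_row < BN // tile_rows
    first_row = tile_row * tile_rows   # only the row index along BN varies
    start = first_row * BK             # consecutive rows are BK elements apart
    return first_row, start, tile_rows * BK
```

With tile_rows = warp_tile_rows = 32, BN = 64, and BK = 32, tile_row = 1 starts at row 32, element offset 1024, and spans 1024 elements.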

load_from_dram

def load_from_dram[strip_idx: Int](self, gmem_tile: TileTensor[kv_t.dtype, Storage=gmem_tile.Storage, address_space=gmem_tile.address_space, linear_idx_type=gmem_tile.linear_idx_type, element_size=gmem_tile.element_size])

Load one BK-wide strip from an external DRAM tile to SMEM.

K (transpose=True): columns [strip_idx*BK, (strip_idx+1)*BK) of the BN x depth tile.

V (transpose=False): rows [strip_idx*BK, (strip_idx+1)*BK) of the BN x depth tile.
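The strip selection can be modeled in plain Python, with nested lists standing in for the DRAM tile and the SMEM buffer. This sketch shows only which elements each strip_idx selects; the real method performs a DMA into shared memory, and stream_strips is a hypothetical name.

```python
# Plain-Python model of per-strip selection from one BN x depth tile.

def stream_strips(tile, BK: int, transpose: bool):
    """Yield the SMEM contents for each strip_idx of one BN x depth tile."""
    BN, depth = len(tile), len(tile[0])
    span = depth if transpose else BN   # K walks columns, V walks rows
    assert span % BK == 0, "sketch assumes BK divides the strided dimension"
    for strip_idx in range(span // BK):
        lo, hi = strip_idx * BK, (strip_idx + 1) * BK
        if transpose:
            yield [row[lo:hi] for row in tile]      # K strip: BN x BK
        else:
            yield [row[:] for row in tile[lo:hi]]   # V strip: BK x depth
```

For a 4 x 8 tile with BK = 4, the K path yields two 4 x 4 strips, the second holding columns 4 to 7 of every row.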

load_from_shared

def load_from_shared(self)

Load from SMEM to MMA registers.

get_mma_tile

def get_mma_tile[k_mma_idx: Int](self) -> TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Get the register tile for one k_mma group within the current strip.

Returns:

TileTensor[kv_t.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
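The register storage described by mma_layout has num_mmas * num_k_mmas2 rows, so one k_mma group plausibly corresponds to num_mmas of them. The sketch below assumes the rows are grouped k-major (all N-direction MMAs for k_mma_idx 0 first); that ordering is a hypothesis, not confirmed by this page, and select_k_group is a hypothetical name.

```python
# Plain-Python model of picking one k_mma group's fragments.
# ASSUMPTION: rows are grouped k-major; the real ordering may differ.

def select_k_group(regs, k_mma_idx: int, num_mmas: int, num_k_mmas2: int):
    """Return the num_mmas fragment rows for one k_mma group."""
    assert len(regs) == num_mmas * num_k_mmas2
    assert 0 <= k_mma_idx < num_k_mmas2
    start = k_mma_idx * num_mmas
    return regs[start:start + num_mmas]
```

With num_mmas = 2 and num_k_mmas2 = 2, k_mma_idx = 1 selects register rows 2 and 3 under this assumption.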