IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

DecodeKVBuffer

struct DecodeKVBuffer[dtype: DType, kv_tile_layout: TensorLayout, //, config: KVBufferConfig, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False]

Fields

  • load_tile (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].LoadTile):
  • mma_tile (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MmaTile):
  • smem_tile (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].SmemTile):
  • gmem_tile (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].GmemTileType):
  • reg_loader (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].RegLoaderType):
  • tile_idx (Int):
  • load_tile_id (Int):

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

GmemTileType

comptime GmemTileType = TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin]

LoadTile

comptime LoadTile = TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

MMA_K

comptime MMA_K = tensor_core_mma.shape[2]

MMA_N

comptime MMA_N = tensor_core_mma.shape[1]

MmaTile

comptime MmaTile = TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_K * tensor_core_mma.group_size))

num_mmas

comptime num_mmas = ceildiv(config.wsize, DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_N)

num_warps_n

comptime num_warps_n = (BN // WN)

RegLoaderType

comptime RegLoaderType = RegTileLoader[dtype, row_major[((min(num_threads, ((config.btile_dim0 * config.btile_dim1) // DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)) * DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width) // config.btile_dim1) if token_gen else (num_threads // 4), (config.btile_dim1 // DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width) if token_gen else 4](), num_threads]

simd_width

comptime simd_width = simd_width_of[dtype]()

SmemTile

comptime SmemTile = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

wtile_dim0

comptime wtile_dim0 = config.wtile_dim0

wtile_dim1

comptime wtile_dim1 = config.wtile_dim1

Methods

__init__

def __init__(out self, gmem_tile: TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin], smem_tile: TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED])

load_from_dram

def load_from_dram(mut self)

get_mma_tile

def get_mma_tile(self) -> TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

Returns:

TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

copy_to_shared

def copy_to_shared[tile_id: Int = 0](self)

load_from_shared

def load_from_shared[k_mma: Int](self)