Skip to main content

Mojo struct

KVBufferImpl

struct KVBufferImpl[dtype: DType, kv_tile_layout: TensorLayout, //, config: KVBufferConfig, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False]

Fields

  • load_tile (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].LoadTile):
  • mma_tile (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MmaTile):
  • smem_tile (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].SmemTile):
  • gmem_tile (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].GmemTileType):
  • reg_loader (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].RegLoaderType):
  • tile_idx (Int):
  • load_tile_id (Int):

Implemented traits

AnyType, ImplicitlyDestructible, KVBuffer

comptime members

GmemTileType

comptime GmemTileType = TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin]

LoadTile

comptime LoadTile = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

MMA_K

comptime MMA_K = tensor_core_mma.shape[2]

MMA_N

comptime MMA_N = tensor_core_mma.shape[1]

mma_tile_layout

comptime mma_tile_layout = Layout.row_major(KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].num_mmas, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)

MmaTile

comptime MmaTile = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_K * tensor_core_mma.group_size))

num_mmas

comptime num_mmas = ceildiv(config.wsize, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_N)

num_warps_n

comptime num_warps_n = (BN // WN)

RegLoaderType

comptime RegLoaderType = RegTileLoader[dtype, ((min(num_threads, ((config.btile_dim0 * config.btile_dim1) // KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)) * KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width) // config.btile_dim1) if token_gen else (num_threads // 4), (config.btile_dim1 // KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width) if token_gen else 4, num_threads]

simd_width

comptime simd_width = simd_width_of[dtype]()

SmemTile

comptime SmemTile = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

wtile_dim0

comptime wtile_dim0 = config.wtile_dim0

wtile_dim1

comptime wtile_dim1 = config.wtile_dim1

Methods

__init__

__init__(out self, gmem_tile: TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

load_from_dram

load_from_dram(mut self)

get_mma_tile

get_mma_tile(self) -> LayoutTensor[dtype, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

Returns:

LayoutTensor

copy_to_shared

copy_to_shared[tile_id: Int = 0](self)

load_from_shared

load_from_shared[k_mma: Int](self)

Was this page helpful?