Mojo struct
DecodeKVBuffer
struct DecodeKVBuffer[dtype: DType, kv_tile_layout: TensorLayout, //, config: KVBufferConfig, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False]
Fields
- load_tile (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].LoadTile):
- mma_tile (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MmaTile):
- smem_tile (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].SmemTile):
- gmem_tile (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].GmemTileType):
- reg_loader (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].RegLoaderType):
- tile_idx (Int):
- load_tile_id (Int):
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
GmemTileType
comptime GmemTileType = TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin]
LoadTile
comptime LoadTile = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
MMA_K
comptime MMA_K = shape[2]
MMA_N
comptime MMA_N = shape[1]
MmaTile
comptime MmaTile = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
num_k_tiles
comptime num_k_tiles = ceildiv(BK, (DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_K * group_size))
num_mmas
comptime num_mmas = ceildiv(config.wsize, DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_N)
num_warps_n
comptime num_warps_n = (BN // WN)
RegLoaderType
comptime RegLoaderType = RegTileLoader[dtype, row_major[((min(num_threads, ((config.btile_dim0 * config.btile_dim1) // DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)) * DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width) // config.btile_dim1) if token_gen else (num_threads // 4), (config.btile_dim1 // DecodeKVBuffer[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width) if token_gen else 4](), num_threads]
simd_width
comptime simd_width = simd_width_of[dtype]()
SmemTile
comptime SmemTile = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
wtile_dim0
comptime wtile_dim0 = config.wtile_dim0
wtile_dim1
comptime wtile_dim1 = config.wtile_dim1
Methods
__init__
__init__(out self, gmem_tile: TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin], smem_tile: TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED])
load_from_dram
load_from_dram(mut self)
get_mma_tile
get_mma_tile(self) -> TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
Returns:
TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
copy_to_shared
copy_to_shared[tile_id: Int = 0](self)
load_from_shared
load_from_shared[k_mma: Int](self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!