Mojo struct
KVBufferLDS
struct KVBufferLDS[dtype: DType, kv_tile_layout: TensorLayout, //, config: KVBufferConfig, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, token_gen: Bool = False]
KV buffer with direct DRAM→LDS DMA (no register staging).
Serialized single-buffer variant for comparison with KVBufferImpl's register double-buffered approach. Each iteration:
- load_from_dram(): issues DRAM→LDS DMA via LdsTileLoader
- copy_to_shared(): s_waitcnt vmcnt(0) to wait for DMA
- barrier(): make SMEM visible
- load_from_shared(): SMEM→registers via TiledMmaOp.load_b
- MMA
- barrier()
There is no register staging buffer and no prefetch overlap between iterations. Requires `warp_id` for DMA work distribution.
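The per-iteration protocol above can be sketched as follows. This is illustrative only: the names `kv_buf` and `num_tiles` are hypothetical stand-ins, and the actual call sites live in the enclosing attention kernel, which is not shown here.

```mojo
# Hedged sketch, assuming `kv_buf` is a KVBufferLDS constructed with the K (or V)
# gmem tile, a SHARED-memory pointer, and this thread's warp_id.
for _ in range(num_tiles):
    kv_buf.load_from_dram()     # issue DRAM→LDS DMA via LdsTileLoader
    kv_buf.copy_to_shared()     # s_waitcnt vmcnt(0): wait for the DMA to land
    barrier()                   # make the SMEM tile visible to all warps

    @parameter
    for k_mma in range(num_k_tiles):
        kv_buf.load_from_shared[k_mma]()   # SMEM → registers via load_b
        # ... MMA on kv_buf.get_mma_tile() ...

    barrier()                   # single SMEM buffer: serialize before reuse
```

Because there is only one shared-memory buffer, the trailing `barrier()` serializes iterations; this is the intended contrast with `KVBufferImpl`'s register double-buffering.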
Fields
- mma_tile (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].MmaTile)
- smem_tile (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].SmemTile)
- gmem_tile (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].GmemTileType)
- lds_loader (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].LdsLoaderType)
- tile_idx (Int)
- warp_id (UInt32)
Implemented traits
AnyType,
ImplicitlyDestructible,
KVBuffer
comptime members
GmemTileType
comptime GmemTileType = TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin]
LdsLoaderType
comptime LdsLoaderType = LdsTileLoader[dtype, swizzle]
MMA_K
comptime MMA_K = shape[2]
MMA_N
comptime MMA_N = shape[1]
mma_tile_layout
comptime mma_tile_layout = Layout.row_major(KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].num_mmas, KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].simd_width)
MmaTile
comptime MmaTile = TileTensor[dtype, Layout[ComptimeInt[KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].num_mmas], ComptimeInt[KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].simd_width], ComptimeInt[(ComptimeInt[KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].simd_width].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]], MutExternalOrigin, address_space=AddressSpace.LOCAL]
num_k_tiles
comptime num_k_tiles = ceildiv(BK, (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].MMA_K * group_size))
num_mmas
comptime num_mmas = ceildiv(config.wsize, KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].MMA_N)
num_warps_n
comptime num_warps_n = (BN // WN)
simd_width
comptime simd_width = simd_width_of[dtype]()
SmemTile
comptime SmemTile = TileTensor[dtype, Layout[ComptimeInt[config.btile_dim0], ComptimeInt[config.btile_dim1], ComptimeInt[(ComptimeInt[config.btile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED]
vm_instrs_per_load
comptime vm_instrs_per_load = SIMD((ceildiv(((config.btile_dim0 // 32) * max((config.btile_dim1 // 32), 1)), (num_threads // WARP_SIZE)) * 2))
warp_tile_rows
comptime warp_tile_rows = 32
wtile_dim0
comptime wtile_dim0 = config.wtile_dim0
wtile_dim1
comptime wtile_dim1 = config.wtile_dim1
Methods
__init__
__init__(out self, gmem_tile: TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], warp_id: UInt32)
load_from_dram
load_from_dram(mut self)
Issue direct DRAM→LDS DMA distributed across warps.
get_mma_tile
get_mma_tile(self) -> LayoutTensor[dtype, KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
Returns: The MMA operand tile in register (LOCAL address space) memory.
copy_to_shared
copy_to_shared[tile_id: Int = 0](self)
Wait for outstanding DRAM→LDS DMA to complete.
load_from_shared
load_from_shared[k_mma: Int](self)
Load MMA operands for MMA step `k_mma` from SMEM into registers.