Mojo struct

KVBufferLDS

struct KVBufferLDS[dtype: DType, kv_tile_layout: TensorLayout, //, config: KVBufferConfig, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, token_gen: Bool = False]

KV buffer with direct DRAM→LDS DMA (no register staging).

Serialized single-buffer variant, provided for comparison with KVBufferImpl's register double-buffered approach. Each iteration runs the following steps in order:

  1. load_from_dram(): issues DRAM→LDS DMA via LdsTileLoader
  2. copy_to_shared(): s_waitcnt vmcnt(0) to wait for DMA
  3. barrier(): make the SMEM (LDS) contents visible to all threads in the block
  4. load_from_shared(): SMEM→registers via TiledMmaOp.load_b
  5. MMA
  6. barrier()

There is no register staging buffer and no prefetch overlap between iterations. The buffer requires warp_id so that the DMA work can be distributed across warps.
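
The per-iteration ordering can be modeled as a plain Python trace. This is only an illustrative sketch: `iteration_steps` and `run_tiles` are hypothetical helpers, not part of the Mojo API, and the step strings simply mirror the names used in the list above.

```python
def iteration_steps():
    """Return the six serialized steps of one KVBufferLDS iteration, in order."""
    return [
        "load_from_dram",    # issue DRAM->LDS DMA
        "copy_to_shared",    # s_waitcnt vmcnt(0): wait for the DMA
        "barrier",           # make SMEM visible
        "load_from_shared",  # SMEM -> registers
        "mma",               # matrix multiply-accumulate
        "barrier",           # trailing barrier
    ]


def run_tiles(num_tiles):
    """Trace num_tiles iterations back to back, with no overlap between them."""
    trace = []
    for tile in range(num_tiles):
        for step in iteration_steps():
            trace.append((tile, step))
    return trace


# Hypothetical usage: two tiles processed strictly one after the other.
for tile, step in run_tiles(2):
    print(tile, step)
```

In this model the load for tile 1 is not issued until every step for tile 0 has finished, which is the "no prefetch overlap" property described above.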

Fields

  • mma_tile (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].MmaTile)
  • smem_tile (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].SmemTile)
  • gmem_tile (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].GmemTileType)
  • lds_loader (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].LdsLoaderType)
  • tile_idx (Int)
  • warp_id (UInt32)

Implemented traits

AnyType, ImplicitlyDestructible, KVBuffer

comptime members

GmemTileType

comptime GmemTileType = TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin]

LdsLoaderType

comptime LdsLoaderType = LdsTileLoader[dtype, swizzle]

MMA_K

comptime MMA_K = shape[2]

MMA_N

comptime MMA_N = shape[1]

mma_tile_layout

comptime mma_tile_layout = Layout.row_major(KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].num_mmas, KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].simd_width)

MmaTile

comptime MmaTile = TileTensor[dtype, Layout[ComptimeInt[KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].num_mmas], ComptimeInt[KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].simd_width], ComptimeInt[(ComptimeInt[KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].simd_width].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]], MutExternalOrigin, address_space=AddressSpace.LOCAL]

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].MMA_K * group_size))

num_mmas

comptime num_mmas = ceildiv(config.wsize, KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].MMA_N)

num_warps_n

comptime num_warps_n = (BN // WN)
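
As a worked example, the three tile-count members (num_k_tiles, num_mmas, num_warps_n) can be evaluated in plain Python. This is a sketch under stated assumptions: `tile_counts` is a hypothetical helper, and every input value below (including the MMA shape and group_size) is made up for illustration rather than taken from a real configuration.

```python
def ceildiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)


def tile_counts(BN, WN, BK, wsize, mma_n, mma_k, group_size):
    """Mirror the comptime expressions for the tile-count members."""
    return {
        "num_k_tiles": ceildiv(BK, mma_k * group_size),
        "num_mmas": ceildiv(wsize, mma_n),
        "num_warps_n": BN // WN,
    }


# Hypothetical parameters, chosen only to make the arithmetic easy to follow.
counts = tile_counts(BN=128, WN=32, BK=32, wsize=32,
                     mma_n=16, mma_k=16, group_size=2)
print(counts)  # {'num_k_tiles': 1, 'num_mmas': 2, 'num_warps_n': 4}
```

Note that num_k_tiles and num_mmas round up, while num_warps_n uses floor division, so BN is presumably expected to be a multiple of WN.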

simd_width

comptime simd_width = simd_width_of[dtype]()

SmemTile

comptime SmemTile = TileTensor[dtype, Layout[ComptimeInt[config.btile_dim0], ComptimeInt[config.btile_dim1], ComptimeInt[(ComptimeInt[config.btile_dim1].static_value * ComptimeInt[1].static_value)], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED]
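
The Layout arguments above read as a shape of (btile_dim0, btile_dim1) with strides (btile_dim1 * 1, 1), that is, a row-major tile. The following Python sketch computes the linear element offset under that reading. `smem_offset` is a hypothetical helper, the tile dimensions are invented for the example, and any swizzle applied on top of the layout is deliberately ignored.

```python
def smem_offset(row, col, btile_dim0, btile_dim1):
    """Linear element offset of (row, col) in a row-major btile_dim0 x btile_dim1 tile."""
    if not (0 <= row < btile_dim0 and 0 <= col < btile_dim1):
        raise IndexError("coordinate outside the tile")
    stride0 = btile_dim1 * 1  # elements between consecutive rows
    stride1 = 1               # elements between consecutive columns
    return row * stride0 + col * stride1


# Hypothetical 64 x 32 tile.
print(smem_offset(1, 0, 64, 32))  # 32
print(smem_offset(2, 5, 64, 32))  # 69
```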

vm_instrs_per_load

comptime vm_instrs_per_load = SIMD((ceildiv(((config.btile_dim0 // 32) * max((config.btile_dim1 // 32), 1)), (num_threads // WARP_SIZE)) * 2))
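
To make the vm_instrs_per_load expression easier to check by hand, here is the same arithmetic in Python. This is a sketch only: the function name reuses the member name for readability, the input values are hypothetical, and the warp size of 64 is an assumption for the example. The source does not explain the trailing factor of 2, so the code reproduces it without interpreting it.

```python
def ceildiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)


def vm_instrs_per_load(btile_dim0, btile_dim1, num_threads, warp_size):
    """Mirror the comptime expression, returning a plain int instead of a SIMD value."""
    # Count of 32-row by 32-column blocks; the column count is clamped to at least 1.
    blocks = (btile_dim0 // 32) * max(btile_dim1 // 32, 1)
    num_warps = num_threads // warp_size
    # Blocks are divided across warps, rounding up, then doubled as in the source.
    return ceildiv(blocks, num_warps) * 2


# Hypothetical configuration: 64 x 128 tile, 256 threads, warp size 64.
print(vm_instrs_per_load(64, 128, 256, 64))  # 4
# A tile narrower than 32 columns still counts as one block per 32 rows.
print(vm_instrs_per_load(32, 16, 256, 64))   # 2
```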

warp_tile_rows

comptime warp_tile_rows = 32

wtile_dim0

comptime wtile_dim0 = config.wtile_dim0

wtile_dim1

comptime wtile_dim1 = config.wtile_dim1

Methods

__init__

__init__(out self, gmem_tile: TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], warp_id: UInt32)

load_from_dram

load_from_dram(mut self)

Issue direct DRAM→LDS DMA distributed across warps.

get_mma_tile

get_mma_tile(self) -> LayoutTensor[dtype, KVBufferLDS[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, token_gen].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

Returns:

LayoutTensor

copy_to_shared

copy_to_shared[tile_id: Int = 0](self)

Wait for outstanding DRAM→LDS DMA to complete.

load_from_shared

load_from_shared[k_mma: Int](self)

Load MMA operands from SMEM to registers.
