Skip to main content

Mojo struct

KVBuffer

struct KVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], k_group_size: Int, swizzle: OptionalReg[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool]

Fields

  • mma_tile (LayoutTensor[kv_t.dtype, Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL])
  • smem_iter (LayoutTensorIter[kv_t.dtype, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True])
  • kv_cache_iter (KVCacheIterator[kv_t, BN, kv_num_heads, depth])
  • buffer_idx (Int)

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

comptime __del__is_trivial = True if True if True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial

base_layout

comptime base_layout = Layout.row_major(BN, BK)

MMA_K

comptime MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)

MMA_N

comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)

MMATileType

comptime MMATileType = LayoutTensor[kv_t.dtype, Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_k_mmas2

comptime num_k_mmas2 = ceildiv(BK, Int.__init__[Int]((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_K * k_group_size)))

num_mmas

comptime num_mmas = ceildiv(WN if transpose else depth, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_N)

num_repeats

comptime num_repeats = (depth // BK)

SharedIterType

comptime SharedIterType = LayoutTensorIter[kv_t.dtype, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]

SharedTileType

comptime SharedTileType = LayoutTensorIter[kv_t.dtype, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True].LayoutTensorType

SharedWarpTileType

comptime SharedWarpTileType = LayoutTensor[kv_t.dtype, LayoutTensor._compute_tile_layout[True, kv_t.dtype, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_index_type(AddressSpace.SHARED), _get_index_type(AddressSpace.SHARED), False, align_of[kv_t.dtype](), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim0, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim1]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_index_type(AddressSpace.SHARED), linear_idx_type=_get_index_type(AddressSpace.SHARED), masked=_tile_is_masked[KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim0, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim1]()]

simd_width

comptime simd_width = simd_width_of[kv_t.dtype]()

smem_layout

comptime smem_layout = blocked_product(KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].base_layout, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].tiler_layout, False)

tiler_layout

comptime tiler_layout = Layout.row_major(1, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_repeats)

wtile_dim0

comptime wtile_dim0 = WN

wtile_dim1

comptime wtile_dim1 = BK

Methods

__init__

__init__(out self, k_cache: kv_t, batch_idx: UInt, head_idx: UInt, shared_ptr: LegacyUnsafePointer[Scalar[kv_t.dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin], end: UInt)

load_from_dram

load_from_dram(mut self) -> Int

Returns:

Int

get_mma_tile

get_mma_tile[k_mma_tile_idx: Int](self) -> LayoutTensor[kv_t.dtype, LayoutTensor._compute_tile_layout[True, kv_t.dtype, Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), AddressSpace.LOCAL), False, align_of[kv_t.dtype](), (Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width).shape[0].value() // KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), 0]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL]

Returns:

LayoutTensor

copy_to_shared

copy_to_shared(self)

load_from_shared

load_from_shared(self, buffer: UInt, bk_tile: UInt)

Was this page helpful?