Mojo struct
KVBuffer
struct KVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], k_group_size: Int, swizzle: OptionalReg[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool]
Fields
- mma_tile (LayoutTensor[kv_t.dtype, Layout.row_major((ceildiv(WN if swizzle.__bool__[Swizzle]() else depth, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int.__init__[Int]((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size)))), simd_width_of[kv_t.dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]) — the same type as the `MMATileType` alias.
- smem_iter (LayoutTensorIter[kv_t.dtype, blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), MutableAnyOrigin, address_space=AddressSpace(3), circular=True]) — the same type as the `SharedIterType` alias.
- kv_cache_iter (KVCacheIterator[kv_t, BN, kv_num_heads, depth])
- buffer_idx (Int)
Implemented traits
- AnyType
- UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = True if True if True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial
base_layout
alias base_layout = Layout.row_major(BN, BK)
MMA_K
alias MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)
MMA_N
alias MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)
MMATileType
alias MMATileType = LayoutTensor[kv_t.dtype, Layout.row_major((ceildiv(WN if swizzle.__bool__[Swizzle]() else depth, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int.__init__[Int]((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size)))), simd_width_of[kv_t.dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]
num_k_mmas2
alias num_k_mmas2 = ceildiv(BK, Int.__init__[Int]((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size)))
num_mmas
alias num_mmas = ceildiv(WN if swizzle.__bool__[Swizzle]() else depth, mma_shape.__getitem__[3, DType.int64, Int](1))
num_repeats
alias num_repeats = (depth // BK)
SharedIterType
alias SharedIterType = LayoutTensorIter[kv_t.dtype, blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), MutableAnyOrigin, address_space=AddressSpace(3), circular=True]
SharedTileType
alias SharedTileType = LayoutTensor[kv_t.dtype, blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_index_type(AddressSpace(3)), linear_idx_type=_get_index_type(AddressSpace(3))]
SharedWarpTileType
alias SharedWarpTileType = LayoutTensor[kv_t.dtype, LayoutTensor._compute_tile_layout[True, kv_t.dtype, blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), MutableAnyOrigin, AddressSpace(3), Layout(IntTuple(1), IntTuple(1)), _get_index_type(AddressSpace(3)), _get_index_type(AddressSpace(3)), False, align_of[kv_t.dtype](), WN, BK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_index_type(AddressSpace(3)), linear_idx_type=_get_index_type(AddressSpace(3)), masked=_tile_is_masked[blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), WN, BK]()]
simd_width
alias simd_width = simd_width_of[kv_t.dtype]()
smem_layout
alias smem_layout = blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False)
tiler_layout
alias tiler_layout = Layout.row_major(1, (depth // BK))
wtile_dim0
alias wtile_dim0 = WN
wtile_dim1
alias wtile_dim1 = BK
Methods
__init__
__init__(out self, k_cache: kv_t, batch_idx: UInt, head_idx: UInt, shared_ptr: UnsafePointer[Scalar[kv_t.dtype], address_space=AddressSpace(3), mut=mut, origin=origin], end: UInt)
load_from_dram
get_mma_tile
get_mma_tile[k_mma_tile_idx: Int](self) -> LayoutTensor[kv_t.dtype, LayoutTensor._compute_tile_layout[True, kv_t.dtype, Layout.row_major((ceildiv(WN if swizzle.__bool__[Swizzle]() else depth, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int.__init__[Int]((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size)))), simd_width_of[kv_t.dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((ceildiv(WN if swizzle.__bool__[Swizzle]() else depth, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int.__init__[Int]((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size)))), simd_width_of[kv_t.dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major((ceildiv(WN if swizzle.__bool__[Swizzle]() else depth, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int.__init__[Int]((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size)))), simd_width_of[kv_t.dtype]()), AddressSpace(5)), False, align_of[kv_t.dtype](), (Layout.row_major((ceildiv(WN if swizzle.__bool__[Swizzle]() else depth, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int.__init__[Int]((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size)))), simd_width_of[kv_t.dtype]()).shape[0].value() // ceildiv(BK, Int.__init__[Int]((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size)))), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5)]
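Returns: A `LayoutTensor` in `AddressSpace(5)`, with the exact type given in the signature above. Its layout is a tile of the `mma_tile` layout with `shape[0] // num_k_mmas2` (that is, `num_mmas`) rows. Presumably this is the sub-tile selected by `k_mma_tile_idx`; confirm against the implementation.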
copy_to_shared
copy_to_shared(self)
load_from_shared
load_from_shared(self, buffer: UInt, bk_tile: UInt)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!