Skip to main content

Mojo struct

KBuffer

struct KBuffer[k_t: MHAOperand, //, mma_shape: IndexList[3], k_group_size: Int, swizzle: OptionalReg[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int]

Fields

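  • mma_tile (LayoutTensor[k_t.dtype, Layout.row_major((ceildiv(WN, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int(UInt((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))))), simd_width_of[k_t.dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]): A row-major tile with `num_mmas * num_k_mmas2` rows and `simd_width` columns, stored in address space 5 (presumably thread-local/private memory — confirm). Its type is the `MMATileType` alias.
  • smem_iter (LayoutTensorIter[k_t.dtype, blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), MutableAnyOrigin, address_space=AddressSpace(3), circular=True]): A circular iterator over tiles with the `smem_layout` layout, located in address space 3 (the same address space as the `shared_ptr` passed to `__init__`). Its type is the `SharedIterType` alias.
  • kv_cache_iter (KVCacheIterator[k_t, BN, kv_num_heads, depth]): An iterator over the K cache, parameterized by the block size `BN`, `kv_num_heads`, and `depth`.
  • buffer_idx (Int): Index of the current buffer. This is presumably the position within the circular shared-memory iterator; verify against the implementation.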

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = MMATileType.__del__is_trivial and SharedIterType.__del__is_trivial and KVCacheIterator[k_t, BN, kv_num_heads, depth].__del__is_trivial and Int.__del__is_trivial

This is `True` only when the destructors of all four field types (`mma_tile`, `smem_iter`, `kv_cache_iter`, `buffer_idx`) are trivial. The generated form of this alias was an equivalent nested `b if a else a` chain, which for Bool values reduces to `a and b`.

base_layout

alias base_layout = Layout.row_major(BN, BK)

MMA_K

alias MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)

MMA_N

alias MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)

MMATileType

alias MMATileType = LayoutTensor[k_t.dtype, Layout.row_major((ceildiv(WN, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int(UInt((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))))), simd_width_of[k_t.dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]

num_k_mmas2

alias num_k_mmas2 = ceildiv(BK, Int(UInt((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))))

num_mmas

alias num_mmas = ceildiv(WN, mma_shape.__getitem__[3, DType.int64, Int](1))

num_repeats

alias num_repeats = (depth // BK)

SharedIterType

alias SharedIterType = LayoutTensorIter[k_t.dtype, blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), MutableAnyOrigin, address_space=AddressSpace(3), circular=True]

SharedTileType

alias SharedTileType = LayoutTensor[k_t.dtype, blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_index_type(AddressSpace(3)), linear_idx_type=_get_index_type(AddressSpace(3))]

SharedWarpTileType

alias SharedWarpTileType = LayoutTensor[k_t.dtype, LayoutTensor._compute_tile_layout[True, k_t.dtype, blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_index_type(AddressSpace(3)), _get_index_type(AddressSpace(3)), False, align_of[k_t.dtype](), WN, BK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_index_type(AddressSpace(3)), linear_idx_type=_get_index_type(AddressSpace(3)), masked=_tile_is_masked[blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False), WN, BK]()]

simd_width

alias simd_width = simd_width_of[k_t.dtype]()

smem_layout

alias smem_layout = blocked_product(Layout.row_major(BN, BK), Layout.row_major(1, (depth // BK)), False)

tiler_layout

alias tiler_layout = Layout.row_major(1, (depth // BK))

wtile_dim0

alias wtile_dim0 = WN

wtile_dim1

alias wtile_dim1 = BK

Methods

__init__

__init__(out self, k_cache: k_t, batch_idx: UInt, head_idx: UInt, shared_ptr: UnsafePointer[Scalar[k_t.dtype], address_space=AddressSpace(3), mut=mut, origin=origin], end: UInt)

load_from_dram

load_from_dram(mut self) -> Int

Returns:

Int

get_mma_tile

get_mma_tile[k_mma_tile_idx: Int](self) -> LayoutTensor[k_t.dtype, LayoutTensor._compute_tile_layout[True, k_t.dtype, Layout.row_major((ceildiv(WN, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int(UInt((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))))), simd_width_of[k_t.dtype]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((ceildiv(WN, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int(UInt((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))))), simd_width_of[k_t.dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major((ceildiv(WN, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int(UInt((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))))), simd_width_of[k_t.dtype]()), AddressSpace(5)), False, align_of[k_t.dtype](), (Layout.row_major((ceildiv(WN, mma_shape.__getitem__[3, DType.int64, Int](1)) * ceildiv(BK, Int(UInt((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))))), simd_width_of[k_t.dtype]()).shape[0].value[ComptimeOrigin]() // ceildiv(BK, Int(UInt((mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))))), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5)]

Returns:

LayoutTensor

copy_to_shared

copy_to_shared(self)

load_from_shared

load_from_shared[accum_type: DType, mma_input_type: DType, transpose_b: Bool](self, buffer: UInt, bk_tile: UInt)

Was this page helpful?