Mojo struct
KVBufferImpl
struct KVBufferImpl[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, mut: Bool, dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, origin: Origin[mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, //, config: KVBufferConfig, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: OptionalReg[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False]
Fields
- load_tile (LayoutTensor[dtype, Layout.row_major(((num_stages * ceildiv(config.wsize, shape.__getitem__[3, DType.int64, Int](1))) * ceildiv(BK, (shape.__getitem__[3, DType.int64, Int](2) * group_size))), simd_width_of[dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]):
- mma_tile (LayoutTensor[dtype, Layout.row_major(ceildiv(config.wsize, shape.__getitem__[3, DType.int64, Int](1)), simd_width_of[dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]):
- smem_iter (LayoutTensorIter[dtype, blocked_product(Layout.row_major(config.btile_dim0, simd_width_of[dtype]()), Layout.row_major(1, (config.btile_dim1 // simd_width_of[dtype]())), True) if (token_gen ^ True) else Layout.row_major(config.btile_dim0, config.btile_dim1), MutableAnyOrigin, address_space=AddressSpace(3), circular=True]):
- bounds (Int):
- load_tile_id (Int):
- global_iterator (LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, config.btile_dim0, config.btile_dim1]()[0], origin, address_space=address_space, axis=config.iterator_axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, config.btile_dim0, config.btile_dim1]()]):
Implemented traits
AnyType,
KVBuffer,
UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = True
base_layout
alias base_layout = Layout.row_major(config.btile_dim0, simd_width_of[dtype]())
GlobalTensorType
alias GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]
GlobalTiledIteratorType
alias GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, config.btile_dim0, config.btile_dim1]()[0], origin, address_space=address_space, axis=config.iterator_axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, config.btile_dim0, config.btile_dim1]()]
LoadTileType
alias LoadTileType = LayoutTensor[dtype, Layout.row_major(((num_stages * ceildiv(config.wsize, shape.__getitem__[3, DType.int64, Int](1))) * ceildiv(BK, (shape.__getitem__[3, DType.int64, Int](2) * group_size))), simd_width_of[dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]
MMA_K
alias MMA_K = shape.__getitem__[3, DType.int64, Int](2)
MMA_N
alias MMA_N = shape.__getitem__[3, DType.int64, Int](1)
mma_tile_layout
alias mma_tile_layout = Layout.row_major(ceildiv(config.wsize, shape.__getitem__[3, DType.int64, Int](1)), simd_width_of[dtype]())
MMATileType
alias MMATileType = LayoutTensor[dtype, Layout.row_major(ceildiv(config.wsize, shape.__getitem__[3, DType.int64, Int](1)), simd_width_of[dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]
num_k_tiles
alias num_k_tiles = ceildiv(BK, (shape.__getitem__[3, DType.int64, Int](2) * group_size))
num_mmas
alias num_mmas = ceildiv(config.wsize, shape.__getitem__[3, DType.int64, Int](1))
num_repeats
alias num_repeats = (config.btile_dim1 // simd_width_of[dtype]())
num_warps_n
alias num_warps_n = (BN // WN)
SharedIterType
alias SharedIterType = LayoutTensorIter[dtype, blocked_product(Layout.row_major(config.btile_dim0, simd_width_of[dtype]()), Layout.row_major(1, (config.btile_dim1 // simd_width_of[dtype]())), True) if (token_gen ^ True) else Layout.row_major(config.btile_dim0, config.btile_dim1), MutableAnyOrigin, address_space=AddressSpace(3), circular=True]
SharedTileType
alias SharedTileType = LayoutTensor[dtype, blocked_product(Layout.row_major(config.btile_dim0, simd_width_of[dtype]()), Layout.row_major(1, (config.btile_dim1 // simd_width_of[dtype]())), True) if (token_gen ^ True) else Layout.row_major(config.btile_dim0, config.btile_dim1), MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_index_type(AddressSpace(3)), linear_idx_type=_get_index_type(AddressSpace(3))]
SharedWarpTileType
alias SharedWarpTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, blocked_product(Layout.row_major(config.btile_dim0, simd_width_of[dtype]()), Layout.row_major(1, (config.btile_dim1 // simd_width_of[dtype]())), True) if (token_gen ^ True) else Layout.row_major(config.btile_dim0, config.btile_dim1), MutableAnyOrigin, AddressSpace(3), Layout(IntTuple(1), IntTuple(1)), _get_index_type(AddressSpace(3)), _get_index_type(AddressSpace(3)), False, align_of[dtype](), config.wtile_dim0, config.wtile_dim1]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_index_type(AddressSpace(3)), linear_idx_type=_get_index_type(AddressSpace(3)), masked=_tile_is_masked[blocked_product(Layout.row_major(config.btile_dim0, simd_width_of[dtype]()), Layout.row_major(1, (config.btile_dim1 // simd_width_of[dtype]())), True) if (token_gen ^ True) else Layout.row_major(config.btile_dim0, config.btile_dim1), config.wtile_dim0, config.wtile_dim1]()]
simd_width
alias simd_width = simd_width_of[dtype]()
smem_layout
alias smem_layout = blocked_product(Layout.row_major(config.btile_dim0, simd_width_of[dtype]()), Layout.row_major(1, (config.btile_dim1 // simd_width_of[dtype]())), True) if (token_gen ^ True) else Layout.row_major(config.btile_dim0, config.btile_dim1)
thread_layout
alias thread_layout = Layout.row_major(((min(num_threads, ((config.btile_dim0 * config.btile_dim1) // simd_width_of[dtype]())) * simd_width_of[dtype]()) // blocked_product(Layout.row_major(config.btile_dim0, simd_width_of[dtype]()), Layout.row_major(1, (config.btile_dim1 // simd_width_of[dtype]())), True) if (token_gen ^ True) else Layout.row_major(config.btile_dim0, config.btile_dim1).stride[0].value()), (blocked_product(Layout.row_major(config.btile_dim0, simd_width_of[dtype]()), Layout.row_major(1, (config.btile_dim1 // simd_width_of[dtype]())), True) if (token_gen ^ True) else Layout.row_major(config.btile_dim0, config.btile_dim1).stride[0].value() // simd_width_of[dtype]())) if token_gen else Layout.row_major((num_threads // 4), 4)
tiler_layout
alias tiler_layout = Layout.row_major(1, (config.btile_dim1 // simd_width_of[dtype]()))
wtile_dim0
alias wtile_dim0 = config.wtile_dim0
wtile_dim1
alias wtile_dim1 = config.wtile_dim1
Methods
__init__
__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_b_rows: OptionalReg[Int], shared_ptr: UnsafePointer[Scalar[dtype], address_space=AddressSpace(3), mut=mut, origin=origin])
get_dtype
load_from_dram
load_from_dram(mut self)
get_mma_tile
get_mma_tile(self) -> LayoutTensor[dtype, Layout.row_major(ceildiv(config.wsize, shape.__getitem__[3, DType.int64, Int](1)), simd_width_of[dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]
Returns:
copy_to_shared
copy_to_shared[tile_id: Int = 0](self)
load_from_shared
load_from_shared[k_mma: Int](self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!