Mojo struct
SharedMemoryManager
struct SharedMemoryManager[shared_kv: Bool, full_kv: Bool, depth_padded: Bool, double_buffer: Bool, dtype: DType, BM: Int, BN: Int, BK: Int, depth: Int, token_gen: Bool]
Fields
- p_smem (UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)])
- k_smem (UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)])
- v_smem (UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)])
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = True
accum_type
alias accum_type = get_accum_type[dtype]()
alignment
alias alignment = align_of[SIMD[dtype, simd_width_of[dtype]()]]()
k_smem_size
alias k_smem_size = ((BN * depth if full_kv else BK) * 2 if double_buffer else 1)
p_smem_size
alias p_smem_size = (BM * BN) if token_gen else 0
simd_width
alias simd_width = simd_width_of[dtype]()
v_smem_size
alias v_smem_size = ((BN if full_kv else BK * pad[dtype, depth, depth]() if depth_padded else depth) * 2 if double_buffer else 1)
Methods
__init__
__init__(out self)
get_k_ptr
get_k_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)]
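Returns: UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)] (the K buffer pointer, typed as `_dtype`).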
get_v_ptr
get_v_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)]
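Returns: UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)] (the V buffer pointer, typed as `_dtype`).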
get_p_ptr
get_p_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)]
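Returns: UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)] (the P buffer pointer, typed as `_dtype`).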
get_warp_scratch_ptr
get_warp_scratch_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)]
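Returns: UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)] (the warp scratch pointer, typed as `_dtype`).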
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!