Skip to main content

Mojo struct

SharedMemoryManager

struct SharedMemoryManager[shared_kv: Bool, full_kv: Bool, depth_padded: Bool, double_buffer: Bool, dtype: DType, BM: Int, BN: Int, BK: Int, depth: Int, token_gen: Bool]

Fields

  • p_smem (UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]):
  • k_smem (UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]):
  • v_smem (UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]):

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = True

accum_type

alias accum_type = get_accum_type[dtype]()

alignment

alias alignment = align_of[SIMD[dtype, simd_width_of[dtype]()]]()

k_smem_size

alias k_smem_size = BN * (depth if full_kv else BK) * (2 if double_buffer else 1)

p_smem_size

alias p_smem_size = (BM * BN) if token_gen else 0

simd_width

alias simd_width = simd_width_of[dtype]()

v_smem_size

alias v_smem_size = (BN if full_kv else BK) * (pad[dtype, depth, depth]() if depth_padded else depth) * (2 if double_buffer else 1)

Methods

__init__

__init__(out self)

get_k_ptr

get_k_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)]

Returns:

UnsafePointer

get_v_ptr

get_v_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)]

Returns:

UnsafePointer

get_p_ptr

get_p_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)]

Returns:

UnsafePointer

get_warp_scratch_ptr

get_warp_scratch_ptr[_dtype: DType](self) -> UnsafePointer[Scalar[_dtype], address_space=AddressSpace(3)]

Returns:

UnsafePointer

Was this page helpful?