Skip to main content

Mojo struct

SharedMemoryManager

struct SharedMemoryManager[dtype: DType, BM: Int, BN: Int, BK: Int, depth: Int, num_rowwise_warps: Int, token_gen: Bool]

Fields

  • p_smem (UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3)]): Shared-memory pointer backing the P buffer (exposed through get_p_ptr and get_p_iter).
  • k_v_smem (UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3)]): Shared-memory pointer backing the K/V buffer (exposed through get_kv_ptr).

Implemented traits

AnyType, Defaultable, UnknownDestructibility

Methods

__init__

__init__(out self)

get_kv_ptr

get_kv_ptr[dtype: DType](self) -> UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3), alignment=alignof[::AnyType,__mlir_type.!kgen.target]()]

Returns:

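UnsafePointer: Pointer to the K/V buffer in shared memory (address space 3), typed with the method's own `dtype` parameter.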

get_p_ptr

get_p_ptr(self) -> UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3), alignment=alignof[::AnyType,__mlir_type.!kgen.target]()]

Returns:

UnsafePointer: Pointer to the P buffer in shared memory (address space 3).

get_k_iter

get_k_iter(self) -> LayoutTensorIter[dtype, row_major(BN, BK), MutableAnyOrigin, address_space=AddressSpace(3), circular=True]

Returns:

LayoutTensorIter: A circular iterator over row-major BN × BK tiles in shared memory.

get_v_iter

get_v_iter(self) -> LayoutTensorIter[dtype, row_major(BK, BN), MutableAnyOrigin, address_space=AddressSpace(3), circular=True]

Returns:

LayoutTensorIter: A circular iterator over row-major BK × BN tiles in shared memory.

get_p_iter

get_p_iter(self) -> LayoutTensorIter[dtype, row_major(BM, BK), MutableAnyOrigin, address_space=AddressSpace(3), circular=True]

Returns:

LayoutTensorIter: A circular iterator over row-major BM × BK tiles in shared memory.

get_warp_scratch_tensor

get_warp_scratch_tensor(self) -> LayoutTensor[get_accum_type[::DType,::DType](), row_major((num_rowwise_warps * 2), BM), MutableAnyOrigin, address_space=AddressSpace(3)]

Returns:

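LayoutTensor: A row-major (num_rowwise_warps * 2) × BM tensor in shared memory, using the accumulation dtype, intended as per-warp scratch space.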

Was this page helpful?