Skip to main content

Mojo struct

KVProducerPipeline

@register_passable(trivial) struct KVProducerPipeline[dtype: DType, config: FA4Config]

Fields

  • kv_pipeline (KVPipeline[config.num_kv_stages, config.num_mma_stages]):
  • smem (UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

KPairType

alias KPairType = TMADestination[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]()]

KType

alias KType = LayoutTensor[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode](), MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]

kv_bytes

alias kv_bytes = (tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]().size() * dtype.size_of())

kv_elements

alias kv_elements = tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]().size()

SMemType

alias SMemType = UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]

VPairType

alias VPairType = TMADestination[dtype, tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode]()]

VType

alias VType = LayoutTensor[dtype, tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode](), MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]

Methods

__init__

__init__(mbar: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)], smem: UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]) -> Self

__init__(kv_pipeline: KVPipeline[config.num_kv_stages, config.num_mma_stages], smem: UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]) -> Self

init

init(self)

Only one of the producer or the consumer should call `init()`.

get_kv_smem

get_kv_smem[*, mma_stage: Int](self) -> UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]

Returns:

UnsafePointer

get_k

get_k[*, mma_stage: Int, expect: Bool = True](self) -> TMADestination[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]()]

Returns:

TMADestination

get_v

get_v[*, mma_stage: Int](self) -> TMADestination[dtype, tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode]()]

Returns:

TMADestination

acquire_kv

acquire_kv[*, mma_stage: Int = (config.num_mma_stages - 1)](self)

commit_kv_step

commit_kv_step(mut self)

Steps the KV pipeline. This does not perform the commit on the mbars; that should be handled by `tma_op.async_copy`.

Was this page helpful?