Mojo struct
KVProducerPipeline
@register_passable(trivial)
struct KVProducerPipeline[dtype: DType, config: FA4Config]
Fields
- kv_pipeline (KVPipeline[config.num_kv_stages, config.num_mma_stages])
- smem (LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED])
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
KPairType
alias KPairType = TMADestination[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]()]
KType
alias KType = LayoutTensor[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode](), MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]
kv_bytes
alias kv_bytes = (KVProducerPipeline[dtype, config].kv_elements * size_of[dtype]())
kv_elements
alias kv_elements = tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]().size()
SMemType
alias SMemType = LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]
VPairType
alias VPairType = TMADestination[dtype, tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode]()]
VType
alias VType = LayoutTensor[dtype, tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode](), MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]
Methods
__init__
__init__(mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]) -> Self
__init__(kv_pipeline: KVPipeline[config.num_kv_stages, config.num_mma_stages], smem: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]) -> Self
init
init(self)
Only one of the producer or consumer should call init().
get_kv_smem
get_kv_smem[*, mma_stage: Int](self) -> LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]
Returns:
LegacyUnsafePointer
get_k
get_k[*, mma_stage: Int, expect: Bool = True](self) -> TMADestination[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]()]
Returns:
TMADestination
get_v
get_v[*, mma_stage: Int](self) -> TMADestination[dtype, tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode]()]
Returns:
TMADestination
acquire_kv
acquire_kv[*, mma_stage: Int = (config - 1)](self)
commit_kv_step
commit_kv_step(mut self)
Step the kv pipeline. This does not perform the commit on the mbars; that should be handled by tma_op.async_copy.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!