@register_passable(trivial)
struct KVConsumerPipeline[dtype: DType, config: FA4Config]
Pipeline for managing the consumption of K and V. This follows the order of the Tri Dao and CUTLASS implementations (modulo any rotation of the ops through the iterations).
We consume/produce in the following order:

0. S0 <- Q0 @ Kn'
1. O1 <- O1 + P1 @ V{n-1}
2. S1 <- Q1 @ Kn'
3. O0 <- O0 + P0 @ Vn
Note that there are two MMAs between calculating Si and consuming Pi, maximizing the overlap between MMAs and the softmax calculation. Oi <- Oi + Pi @ V also depends on the correction, which is computed asynchronously with the softmax in a correction warpgroup (as soon as the softmax writes the correction factor).
```
wait on K0
S0 <- Q0 @ K0'
S1 <- Q1 @ K0'
release K0
wait on V0
O0 <- P0 @ V0
for n in range(1, num_iters):
    # wait on Kn
    S0 <- Q0 @ Kn'
    O1 <- O1 + P1 @ V{n-1}
    # release V{n-1}
    S1 <- Q1 @ Kn'
    # release Kn
    # wait on Vn
    O0 <- O0 + P0 @ Vn
O1 <- O1 + P1 @ V{num_iters-1}
```
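The schedule above can be checked with a small model. The following is a hypothetical Python sketch (a model of the event order only, not the Mojo API) that emits the op trace and makes the two-MMA gap between producing Si and consuming Pi visible:

```python
def consumer_op_order(num_iters):
    """Model of the K/V consumer schedule; op strings mirror the pseudocode."""
    ops = [
        "wait K0",
        "S0 <- Q0 @ K0'", "S1 <- Q1 @ K0'",
        "release K0",
        "wait V0",
        "O0 <- P0 @ V0",
    ]
    for n in range(1, num_iters):
        ops += [
            f"wait K{n}",
            f"S0 <- Q0 @ K{n}'",        # produce S0 for this iteration
            f"O1 <- O1 + P1 @ V{n-1}",  # consume P1 from the previous S1
            f"release V{n-1}",
            f"S1 <- Q1 @ K{n}'",
            f"release K{n}",
            f"wait V{n}",
            f"O0 <- O0 + P0 @ V{n}",    # consume P0: two MMAs after S0
        ]
    ops.append(f"O1 <- O1 + P1 @ V{num_iters - 1}")
    return ops
```

For every n >= 1, exactly two MMAs (O1 <- O1 + P1 @ V{n-1} and S1 <- Q1 @ Kn') fall between the production of S0 and the consumption of P0, giving the softmax warpgroup time to turn S0 into P0.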
```
wK0, rK0, wV0
wK1, rV0, rK1, wV1
wK2, rV1, rK2, wV2
wK3, rV2, rK3, wV3
```

With barrier wait states, wKn(state):

```
wK0(0), rK0(0), wV0(1)
wK1(2), rV0(1), rK1(2), wV1(3)
wK2(4), rV1(3), rK2(4), wV2(5)
wK3(6), rV2(5), rK3(6), wV3(7)
```
Rules:
- wK backs up and increments prior to waiting, except K0
- rK increments after releasing
- rV uses the backup
```
wK0(0), rK0(0), wV0(1)
wK1(2), rV0(1), rK1(2), wV1(3)
wK2(4), rV1(3), rK2(4), wV2(5)
rV2(5)
```
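These rules reproduce the traces above. Here is a hypothetical Python sketch (index arithmetic only, not the Mojo implementation), assuming Kn occupies raw index 2n and Vn raw index 2n+1; in the real pipeline the raw index wraps modulo num_kv_stages and carries a phase bit:

```python
def kv_consumer_schedule(num_iters):
    """Emit the w/r trace with raw pipeline indices in parentheses."""
    events = []
    index = 0
    backup = 0
    for n in range(num_iters):
        if n > 0:
            backup = index  # wK backs up V{n-1}'s index...
            index += 1      # ...and increments prior to waiting (except K0)
        events.append(f"wK{n}({index})")
        if n > 0:
            events.append(f"rV{n - 1}({backup})")  # rV uses the backup
        events.append(f"rK{n}({index})")
        index += 1          # rK increments after releasing
        events.append(f"wV{n}({index})")
    events.append(f"rV{num_iters - 1}({index})")   # final V release
    return events
```

With num_iters = 3 this produces exactly the terminating trace above, ending in rV2(5) with no trailing K wait.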
Fields
- kv_pipeline (KVPipeline[config.num_kv_stages, config.num_mma_stages]):
- k_smem_descriptor (MMASmemDescriptor):
- v_smem_descriptor (MMASmemDescriptor):
- v_pipeline_release_index (UInt32):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
full_kv_bytes
alias full_kv_bytes = ((config * config) * dtype.size_of())
mma_kv_bytes
alias mma_kv_bytes = ((config * config) * dtype.size_of())
Methods
__init__
__init__(kv_pipeline: KVPipeline[config.num_kv_stages, config.num_mma_stages], smem: UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]) -> Self
__init__(mbar: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)], smem: UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]) -> Self
init
init(self)
Only one of the producer or consumer should call init().
wait
wait[*, mma_stage: Int](self) -> UInt32
Wait on k from the producer, and return the k smem descriptor.
wait_k
wait_k[*, mma_stage: Int = (config - 1), pre_increment: Bool = True](mut self) -> MMASmemDescriptor
Wait on k from the producer, and return the k smem descriptor. If pre_increment is true, the pipeline index is backed up and incremented prior to waiting.
wait_v
release_k
release_k[*, mma_stage: Int = (config - 1)](mut self)
Must call producer_commit on the tmem resource before calling consumer_release. release_k does increment the pipeline step.
release_v
release_v[*, mma_stage: Int = (config - 1)](self)
Must call producer_commit on the tmem resource before calling consumer_release. release_v does not increment the pipeline step.
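The contrast between release_k (increments the pipeline step) and release_v (does not) can be modeled directly. A hypothetical Python sketch (not the Mojo API) of the consumer-side index bookkeeping described above:

```python
class PipelineIndexModel:
    """Models the consumer's raw pipeline index per the rules above."""

    def __init__(self):
        self.index = 0   # raw pipeline index (K/V stages interleaved)
        self.backup = 0  # saved V index, consumed later by release_v

    def wait_k(self, pre_increment=True):
        # wK backs up and increments prior to waiting, except for K0.
        if pre_increment:
            self.backup = self.index
            self.index += 1
        return self.index

    def release_k(self):
        # release_k does increment the pipeline step.
        released = self.index
        self.index += 1
        return released

    def release_v(self):
        # release_v does not increment; it returns the backed-up index.
        return self.backup
```

This mirrors the signatures above: release_k takes mut self because it advances the step, while release_v takes self and leaves the index untouched.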