
@register_passable(trivial) struct KVConsumerPipeline[dtype: DType, config: FA4Config]

Pipeline for managing the consumption of K and V. This follows the order used in the Tri Dao and CUTLASS implementations (modulo any rotation of the ops through the iterations).

We consume/produce in the following order:

    0. S0 <- Q0 @ Kn'
    1. O1 <- O1 + P1 @ V{n-1}
    2. S1 <- Q1 @ Kn'
    3. O0 <- O0 + P0 @ Vn

Note that we have two MMAs between calculating Si and consuming Pi, which maximizes the overlap between the MMAs and the softmax calculation. Oi + Pi @ V also depends on the correction, which is computed asynchronously with the softmax in a correction warpgroup (as soon as the softmax writes the correction factor).
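To make the role of the correction concrete, here is a minimal plain-Python sketch for a single query row with scalar values. It assumes the correction is the usual online-softmax rescaling factor, exp(old_max - new_max); the function names and the scalar setup are illustrative and not part of this API.

```python
import math

def attention_row_reference(scores, values):
    """Softmax-weighted sum computed over all scores at once."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

def attention_row_blocked(score_blocks, value_blocks):
    """Same result, one K/V block at a time, rescaling O by a correction."""
    running_max = -math.inf
    row_sum = 0.0
    out = 0.0
    for scores, values in zip(score_blocks, value_blocks):
        new_max = max(running_max, max(scores))
        # Assumed form of the correction factor; exp(-inf) is 0.0 on the first block.
        correction = math.exp(running_max - new_max)
        probs = [math.exp(s - new_max) for s in scores]
        # O <- O * correction + P @ V, with scalar V entries here.
        out = out * correction + sum(p * v for p, v in zip(probs, values))
        row_sum = row_sum * correction + sum(probs)
        running_max = new_max
    return out / row_sum

scores = [0.5, 2.0, -1.0, 3.0]
values = [1.0, 2.0, 3.0, 4.0]
reference = attention_row_reference(scores, values)
blocked = attention_row_blocked([scores[:2], scores[2:]], [values[:2], values[2:]])
print(abs(reference - blocked) < 1e-12)
```

Because each update of O needs the correction for the previous partial result, the accumulate step cannot start until the softmax has written that factor, which is the dependency described above.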

    # wait on K0
    S0 <- Q0 @ K0'
    S1 <- Q1 @ K0'
    # release K0
    # wait on V0
    O0 <- P0 @ V0
    for n in range(1, num_iters):
        # wait on Kn
        S0 <- Q0 @ Kn'
        O1 <- O1 + P1 @ V{n-1}
        # release V{n-1}
        S1 <- Q1 @ Kn'
        # release Kn
        # wait on Vn
        O0 <- O0 + P0 @ Vn
    O1 <- O1 + P1 @ V{num_iters-1}

Resulting order of waits (w) and releases (r), one row per iteration:

    wK0, rK0, wV0
    wK1, rV0, rK1, wV1
    wK2, rV1, rK2, wV2
    wK3, rV2, rK3, wV3
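The wait/release order can be derived mechanically from the loop structure. The following is a small Python sketch, not the Mojo implementation: `kv_event_order` is a hypothetical helper, and the trailing release of the last V after the loop is an assumption taken from the final trace on this page.

```python
def kv_event_order(num_iters):
    """List wait (w) / release (r) events in the order the loop issues them."""
    rows = [["wK0", "rK0", "wV0"]]  # prologue, before the loop
    for n in range(1, num_iters):
        # Each iteration waits on Kn, frees the previous V, frees Kn, waits on Vn.
        rows.append([f"wK{n}", f"rV{n - 1}", f"rK{n}", f"wV{n}"])
    # Assumed: the last V is released after the final O1 update.
    rows.append([f"rV{num_iters - 1}"])
    return rows

for row in kv_event_order(4):
    print(", ".join(row))
```

Note that V{n-1} is released only after Kn has been waited on, so a K stage and a V stage are always in flight together.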

The same order with the pipeline state in parentheses, written wKn(state):

    wK0(0), rK0(0), wV0(1)
    wK1(2), rV0(1), rK1(2), wV1(3)
    wK2(4), rV1(3), rK2(4), wV2(5)
    wK3(6), rV2(5), rK3(6), wV3(7)

Rules:

  • wK backs up and increments prior to waiting, except K0
  • rK increments after releasing
  • rV uses backup

    wK0(0), rK0(0), wV0(1)
    wK1(2), rV0(1), rK1(2), wV1(3)
    wK2(4), rV1(3), rK2(4), wV2(5)
    rV2(5)
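The three rules can be checked against the four-row state table with a small Python model. This is a sketch under assumptions: `KVStateModel` and its attributes are hypothetical, the state is an unbounded counter (the real pipeline presumably wraps over its stages), and `v_release_index` only loosely mirrors the `v_pipeline_release_index` field. The final release after the loop is omitted, since the rules do not spell out how its index is obtained.

```python
class KVStateModel:
    """Toy model of the state bookkeeping; not the Mojo struct."""

    def __init__(self):
        self.state = 0            # current pipeline step
        self.v_release_index = 0  # backup used by release_v (assumed meaning)
        self.trace = []

    def wait_k(self, n, pre_increment=True):
        if pre_increment:
            # wK backs up and increments prior to waiting (skipped for K0).
            self.v_release_index = self.state
            self.state += 1
        self.trace.append(f"wK{n}({self.state})")

    def release_k(self, n):
        self.trace.append(f"rK{n}({self.state})")
        self.state += 1  # rK increments after releasing

    def wait_v(self, n):
        self.trace.append(f"wV{n}({self.state})")

    def release_v(self, n):
        # rV uses the backup and leaves the state untouched.
        self.trace.append(f"rV{n}({self.v_release_index})")


def run(num_iters):
    p = KVStateModel()
    p.wait_k(0, pre_increment=False)
    p.release_k(0)
    p.wait_v(0)
    for n in range(1, num_iters):
        p.wait_k(n)
        p.release_v(n - 1)
        p.release_k(n)
        p.wait_v(n)
    return p.trace

print(", ".join(run(4)))
```

In this model K steps land on even states and V steps on odd states, which matches the table.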

Fields

  • kv_pipeline (KVPipeline[config.num_kv_stages, config.num_mma_stages])
  • k_smem_descriptor (MMASmemDescriptor)
  • v_smem_descriptor (MMASmemDescriptor)
  • v_pipeline_release_index (UInt32)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

full_kv_bytes

alias full_kv_bytes = ((config * config) * dtype.size_of())

mma_kv_bytes

alias mma_kv_bytes = ((config * config) * dtype.size_of())

Methods

__init__

__init__(kv_pipeline: KVPipeline[config.num_kv_stages, config.num_mma_stages], smem: UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]) -> Self

__init__(mbar: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)], smem: UnsafePointer[Scalar[dtype], address_space=AddressSpace(3)]) -> Self

init

init(self)

Only one of the producer or consumer should call init().

wait

wait[*, mma_stage: Int](self) -> UInt32

Wait on k from the producer, and return the k smem descriptor.

Returns:

UInt32

wait_k

wait_k[*, mma_stage: Int = (config - 1), pre_increment: Bool = True](mut self) -> MMASmemDescriptor

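Wait on k from the producer, and return the k smem descriptor. If pre_increment is true, the pipeline state is backed up and incremented prior to waiting (per the rules above, this applies to every K except K0).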

Returns:

MMASmemDescriptor

wait_v

wait_v[*, mma_stage: Int = (config - 1)](self) -> MMASmemDescriptor

Returns:

MMASmemDescriptor

release_k

release_k[*, mma_stage: Int = (config - 1)](mut self)

Must call producer_commit on the tmem resource before calling consumer_release. release_k does increment the pipeline step.

release_v

release_v[*, mma_stage: Int = (config - 1)](self)

Must call producer_commit on the tmem resource before calling consumer_release. release_v does not increment the pipeline step.
