
Mojo struct

TMAConsumerPipeline

struct TMAConsumerPipeline[dtype: DType, config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype], is_k: Bool = True]

Unified consumer pipeline for K and V TMA consumption.

K consumption (is_k=True): uses the k_major layout and supports staged qk_stages.
V consumption (is_k=False): uses the mn_major layout and always uses qk_stage=0.

This follows the order of Tri Dao and Cutlass implementations (modulo any rotation of the ops through the iterations).

We consume/produce in the following order:

0. S0 <- Q0 @ Kn'
1. O1 <- O1 + P1 @ V{n-1}
2. S1 <- Q1 @ Kn'
3. O0 <- O0 + P0 @ Vn

Note that there are two MMAs between calculating Si and consuming Pi, maximizing the overlap between MMAs and the softmax calculation.
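For illustration, one iteration of that ping-pong schedule could be driven roughly as follows. Everything here other than the TMAConsumerPipeline methods (the loop, the q0_desc/q1_desc/p0_desc/p1_desc descriptors, the mma_qk/mma_pv helpers, and elect) is a hypothetical sketch, not part of this API:

```mojo
# Hypothetical consumer loop; only the wait_*/get_*/release_* calls
# come from TMAConsumerPipeline. Step 1 consumes the V tile fetched
# in the previous iteration; step 3 consumes the current one.
for n in range(num_kv_tiles):
    # 0. S0 <- Q0 @ Kn'
    k_pipe.wait_k()
    mma_qk(q0_desc, k_pipe.get_k())

    # 1. O1 <- O1 + P1 @ V{n-1}  (previous iteration's V tile)
    mma_pv(p1_desc, v_pipe.get_v())
    v_pipe.release_v(elect)

    # 2. S1 <- Q1 @ Kn', then hand K back to the TMA producer
    mma_qk(q1_desc, k_pipe.get_k())
    k_pipe.release_k(elect)

    # 3. O0 <- O0 + P0 @ Vn  (this iteration's V tile)
    v_pipe.wait_v()
    mma_pv(p0_desc, v_pipe.get_v())
```

This keeps two MMAs in flight between producing Si and consuming Pi, as described above.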

Fields

  • pipeline (StagedPipeline[config.num_kv_stages, TMAConsumerPipeline[dtype, config, is_k].num_qk_stages_effective]):
  • smem_desc (MMASmemDescriptorPair):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

BK

comptime BK = config.BK0 if is_k else config.BK1

BMN

comptime BMN = config.k_rows_per_cta() if is_k else config.v_cols_per_cta()

full_kv_bytes

comptime full_kv_bytes = (((config.k_rows_per_cta() * config) * size_of[dtype]()) + ((config.k_rows_per_cta() * config.rope_depth()) * config.rope_dtype_size)) if is_k else ((config * config.v_cols_per_cta()) * size_of[dtype]())

is_k_major

comptime is_k_major = is_k

num_qk_stages_effective

comptime num_qk_stages_effective = config.num_qk_stages if is_k else 1

staged_k_bytes

comptime staged_k_bytes = ((config.k_rows_per_cta() * config) * size_of[dtype]())

Methods

__init__

__init__(pipeline: StagedPipeline[config.num_kv_stages, TMAConsumerPipeline[dtype, config, is_k].num_qk_stages_effective], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

__init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
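As an illustrative sketch of the two overloads (the pointer names kv_mbar, k_smem, staged, and v_smem are assumptions, not part of this API):

```mojo
# Build a K consumer directly from a shared-memory barrier + buffer.
k_pipe = TMAConsumerPipeline[dtype, config, is_k=True](kv_mbar, k_smem)

# Build a V consumer from an already-constructed StagedPipeline.
v_pipe = TMAConsumerPipeline[dtype, config, is_k=False](staged, v_smem)
```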

get

get(self) -> MMASmemDescriptorPair

Get the smem descriptor for the current stage.

Returns:

MMASmemDescriptorPair

wait

wait[*, qk_stage: Int = 0](self)

Wait for tile from producer.

release

release[*, qk_stage: Int = 0](mut self, e: Int32)

Release buffer after consuming.
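The generic wait/get/release calls are expected to be paired per stage; a minimal sketch (the elect value, the stage constant, and the MMA step are assumptions):

```mojo
# One stage's lifetime: wait -> get descriptor -> MMA -> release.
pipe.wait[qk_stage=0]()
desc = pipe.get()                # MMASmemDescriptorPair for the current stage
# ... issue the MMA using `desc` ...
pipe.release[qk_stage=0](elect)  # `elect` is a hypothetical Int32 argument
```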

get_k

get_k(self) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

wait_k

wait_k[*, qk_stage: Int = (config - 1)](mut self)

Wait on a K stage from the producer.

release_k

release_k[*, qk_stage: Int = (config - 1)](mut self, e: Int32)

Release K buffer after consuming this stage.

get_v

get_v(self) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

wait_v

wait_v(self)

Wait for V tile.

release_v

release_v(mut self, e: Int32)

Release V buffer after consuming.