
Mojo struct

TMAConsumerPipeline

@register_passable(trivial) struct TMAConsumerPipeline[dtype: DType, config: FA4Config, is_k: Bool = True]

Unified consumer pipeline for K and V TMA consumption.

  • K consumption (is_k=True): uses the k_major layout and supports staged consumption across qk_stages.
  • V consumption (is_k=False): uses the mn_major layout and always uses qk_stage=0.

This follows the operation order used by Tri Dao's and CUTLASS implementations, up to a rotation of the ops across iterations.

We consume/produce in the following order:

  0. S0 <- Q0 @ Kn'
  1. O1 <- O1 + P1 @ V{n-1}
  2. S1 <- Q1 @ Kn'
  3. O0 <- O0 + P0 @ Vn

Note that two MMAs are issued between computing Si and consuming Pi, which maximizes the overlap between the MMAs and the softmax computation.
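The ordering above can be checked with a small host-side Python model. This is a sketch rather than kernel code. It only lists the four MMAs of each iteration as labels and counts how many MMAs are issued between producing Si and consuming Pi, the softmax of Si. The helper names `schedule` and `mmas_between` are hypothetical. The prologue (no P1 exists in the first iteration) and the drain (the final P1 still needs the last V tile) are assumptions of this model.

```python
def schedule(num_iters):
    """Return the MMA issue order as (iteration, op) tuples.

    Each iteration n issues, in the documented order:
      0. S0 <- Q0 @ Kn'
      1. O1 <- O1 + P1 @ V{n-1}   (skipped when n == 0: no P1 exists yet)
      2. S1 <- Q1 @ Kn'
      3. O0 <- O0 + P0 @ Vn
    """
    ops = []
    for n in range(num_iters):
        ops.append((n, "S0"))
        if n > 0:
            ops.append((n, "O1"))  # consumes P1 from iteration n-1
        ops.append((n, "S1"))
        ops.append((n, "O0"))      # consumes P0 from this iteration
    # Assumed drain step: the last P1 still has to be multiplied by V{last}.
    ops.append((num_iters, "O1"))
    return ops


def mmas_between(ops, i, n):
    """Count MMAs issued strictly between Si of iteration n and the MMA
    that consumes Pi (the softmax of that Si)."""
    produce = ops.index((n, f"S{i}"))
    # P0 is consumed in the same iteration; P1 is consumed in the next one.
    consume = ops.index((n if i == 0 else n + 1, f"O{i}"))
    return consume - produce - 1
```

In steady state, both softmax groups get two MMAs of slack. In this model, the very first S0 and the final S1 get only one.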

Fields

  • pipeline (StagedPipeline[config.num_kv_stages, TMAConsumerPipeline[dtype, config, is_k].num_qk_stages_effective]):
  • smem_desc (MMASmemDescriptorPair):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

BK

comptime BK = config.BK0 if is_k else config.BK1

BMN

comptime BMN = config.BN if is_k else config.padded_depth

full_kv_bytes

comptime full_kv_bytes = ((config * config) * size_of[dtype]())

is_k_major

comptime is_k_major = is_k

num_qk_stages_effective

comptime num_qk_stages_effective = config.num_qk_stages if is_k else 1

staged_k_bytes

comptime staged_k_bytes = ((config * config) * size_of[dtype]())
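The selection logic in `BK`, `BMN`, `is_k_major`, and `num_qk_stages_effective` can be mirrored in plain Python. This is a sketch only. `FA4ConfigModel` is a hypothetical stand-in for `FA4Config` that models only the fields these members reference, and the example values in the test are made up.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FA4ConfigModel:
    # Hypothetical stand-in for FA4Config: only the fields referenced by the
    # comptime members are modeled.
    BN: int
    BK0: int
    BK1: int
    padded_depth: int
    num_qk_stages: int


def consumer_params(config, is_k):
    """Mirror the K/V-dependent comptime members of TMAConsumerPipeline."""
    return {
        "BK": config.BK0 if is_k else config.BK1,
        "BMN": config.BN if is_k else config.padded_depth,
        "is_k_major": is_k,
        # V consumption is never split into qk stages.
        "num_qk_stages_effective": config.num_qk_stages if is_k else 1,
    }
```

The byte-count members are omitted from this sketch.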

Methods

__init__

__init__(pipeline: StagedPipeline[config.num_kv_stages, TMAConsumerPipeline[dtype, config, is_k].num_qk_stages_effective], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

__init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

get

get(self) -> MMASmemDescriptorPair

Get the smem descriptor for the current stage.

Returns:

MMASmemDescriptorPair

wait

wait[*, qk_stage: Int = 0](self)

Wait for a tile from the producer.

release

release[*, qk_stage: Int = 0](mut self, e: Int32)

Release the buffer after consuming it.
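The wait/release pair follows the usual multi-stage producer/consumer ring-buffer pattern. The Python model below is a single-threaded sketch of that pattern and is not the `StagedPipeline` implementation. `StagedRingModel` and its methods are hypothetical, and the `e: Int32` argument of `release` is not modeled. The model also assumes that `release` advances to the next stage and flips a phase bit on wrap-around. The `mut self` signature is consistent with this, but the reference does not state it.

```python
class StagedRingModel:
    """Hypothetical host-side model of a multi-stage consumer ring buffer.

    Each stage has a 'full' flag (set by the producer, cleared on release),
    and the consumer tracks its current stage index plus a phase bit that
    flips every time the index wraps back to 0.
    """

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.slots = [None] * num_stages
        self.full = [False] * num_stages
        self.index = 0
        self.phase = 0

    def produce(self, stage, tile):
        # A producer must not overwrite a stage the consumer still holds.
        assert not self.full[stage], "stage not released yet"
        self.slots[stage] = tile
        self.full[stage] = True

    def wait(self):
        # On the GPU this blocks on the stage's barrier; here we just check
        # that the producer has already filled the current stage.
        if not self.full[self.index]:
            raise RuntimeError(f"stage {self.index} not ready")
        return self.slots[self.index]

    def release(self):
        # Hand the stage back to the producer, then advance (assumed).
        self.full[self.index] = False
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1
```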

get_k

get_k(self) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

wait_k

wait_k[*, qk_stage: Int = (config - 1)](mut self)

Wait for the K stage from the producer.

release_k

release_k[*, qk_stage: Int = (config - 1)](mut self, e: Int32)

Release K buffer after consuming this stage.

get_v

get_v(self) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

wait_v

wait_v(self)

Wait for the V tile.

release_v

release_v(mut self, e: Int32)

Release the V buffer after consuming it.
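The documented MMA order (S0, O1, S1, O0) determines how long each K and V tile must stay resident. The Python trace below is a hedged sketch of that lifetime. Each "wait" and "release" event marks where a kernel would plausibly call `wait_k`/`release_k` or `wait_v`/`release_v`. The model assumes each tile is waited on before its first reader and released right after its last reader, and it does not model the `qk_stage` split of K. `consumer_trace` is a hypothetical helper.

```python
def consumer_trace(num_iters):
    """List consumer-side events for `num_iters` KV iterations.

    Follows the documented order S0 <- Q0 @ Kn', O1 <- O1 + P1 @ V{n-1},
    S1 <- Q1 @ Kn', O0 <- O0 + P0 @ Vn. Release timing is an assumption.
    """
    events = []
    for n in range(num_iters):
        events.append(f"wait K{n}")
        events.append(f"S0 <- Q0 @ K{n}'")
        if n > 0:
            # Second (and last) reader of V{n-1}; O0 read it last iteration.
            events.append(f"O1 <- O1 + P1 @ V{n - 1}")
            events.append(f"release V{n - 1}")
        events.append(f"S1 <- Q1 @ K{n}'")
        events.append(f"release K{n}")  # both S0 and S1 have read Kn
        events.append(f"wait V{n}")
        events.append(f"O0 <- O0 + P0 @ V{n}")
    # Drain: the final P1 still needs the final V tile.
    last = num_iters - 1
    events.append(f"O1 <- O1 + P1 @ V{last}")
    events.append(f"release V{last}")
    return events
```

In this model, each K tile is released within its own iteration. Each V tile is read twice, by O0 in iteration n and by O1 in iteration n+1, so it is held across an iteration boundary.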
