
Mojo struct

TMAConsumerPipeline

struct TMAConsumerPipeline[dtype: DType, config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype], is_k: Bool = True]

Unified consumer pipeline for K and V TMA consumption.

  • K consumption (is_k=True): uses the k_major layout and supports staged qk_stages.
  • V consumption (is_k=False): uses the mn_major layout and always uses qk_stage=0.

This follows the order used by the Tri Dao and CUTLASS implementations (modulo any rotation of the ops through the iterations).

We consume/produce in the following order:

  0. S0 <- Q0 @ Kn'
  1. O1 <- O1 + P1 @ V{n-1}
  2. S1 <- Q1 @ Kn'
  3. O0 <- O0 + P0 @ Vn

Note that two MMAs sit between calculating Si and consuming Pi, which maximizes the overlap between the MMAs and the softmax calculation.
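The spacing claim can be checked with a small model. The following Python sketch is not part of the Mojo API: it lists the MMAs in the order given above and counts how many other MMAs are issued between computing each Si and consuming the matching Pi. The names `mma_schedule` and `mmas_between` are hypothetical, and the handling of the first iteration (no V{n-1} yet) and the final drain step is an assumption, since this page does not describe the prologue or epilogue.

```python
def mma_schedule(num_kv_tiles):
    """List MMAs as (kind, i, n): 'S' computes Si from K tile n,
    'O' accumulates Oi from Pi and V tile n."""
    ops = []
    for n in range(num_kv_tiles):
        ops.append(("S", 0, n))              # 0. S0 <- Q0 @ Kn'
        if n > 0:                            # assumption: skipped when no V{n-1} exists
            ops.append(("O", 1, n - 1))      # 1. O1 <- O1 + P1 @ V{n-1}
        ops.append(("S", 1, n))              # 2. S1 <- Q1 @ Kn'
        ops.append(("O", 0, n))              # 3. O0 <- O0 + P0 @ Vn
    ops.append(("O", 1, num_kv_tiles - 1))   # assumption: drain the last P1 @ V
    return ops


def mmas_between(ops):
    """For each S op, count the MMAs issued before the matching O op."""
    gaps = []
    for idx, (kind, i, n) in enumerate(ops):
        if kind == "S":
            use = ops.index(("O", i, n))
            gaps.append(use - idx - 1)
    return gaps


if __name__ == "__main__":
    print(mmas_between(mma_schedule(3)))
```

In this model every steady-state Si has two MMAs between its computation and the use of Pi; only the very first S0 and the very last S1 have one, as a consequence of the assumed prologue and drain.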

Fields​

  • ​pipeline (StagedPipeline[config.num_kv_stages, TMAConsumerPipeline[dtype, config, is_k].num_qk_stages_effective]):
  • ​smem_desc (MMASmemDescriptorPair):

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

BK​

comptime BK = config.BK0 if is_k else config.BK1

BMN​

comptime BMN = config.k_rows_per_cta() if is_k else config.v_cols_per_cta()

full_kv_bytes​

comptime full_kv_bytes = (((config.k_rows_per_cta() * config) * size_of[dtype]()) + ((config.k_rows_per_cta() * config.rope_depth()) * config.rope_dtype_size)) if is_k else ((config * config.v_cols_per_cta()) * size_of[dtype]())

is_k_major​

comptime is_k_major = is_k

num_qk_stages_effective​

comptime num_qk_stages_effective = config.num_qk_stages if is_k else 1
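As an illustration of how the `is_k` parameter selects these values, here is a minimal Python model of the `BK`, `BMN`, `is_k_major`, and `num_qk_stages_effective` selectors. `ToyConfig` is a hypothetical stand-in for `FA4Config`: its field values are made up, and `k_rows` and `v_cols` stand in for the results of `config.k_rows_per_cta()` and `config.v_cols_per_cta()`. In Mojo these are resolved at compile time; the Python version only mirrors the selection logic.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyConfig:
    # Hypothetical values, chosen only for illustration.
    BK0: int = 64
    BK1: int = 128
    num_qk_stages: int = 2
    k_rows: int = 128   # stands in for config.k_rows_per_cta()
    v_cols: int = 256   # stands in for config.v_cols_per_cta()


def consumer_params(config, is_k=True):
    """Mirror the conditional comptime members documented on this page."""
    return {
        "BK": config.BK0 if is_k else config.BK1,
        "BMN": config.k_rows if is_k else config.v_cols,
        "is_k_major": is_k,
        "num_qk_stages_effective": config.num_qk_stages if is_k else 1,
    }


if __name__ == "__main__":
    cfg = ToyConfig()
    print(consumer_params(cfg, is_k=True))
    print(consumer_params(cfg, is_k=False))
```

With `is_k=False`, the effective number of QK stages collapses to 1, which matches the statement that V consumption always uses `qk_stage=0`.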

staged_k_bytes​

comptime staged_k_bytes = ((config.k_rows_per_cta() * config) * size_of[dtype]())

Methods​

__init__​

def __init__(pipeline: StagedPipeline[config.num_kv_stages, Self.num_qk_stages_effective], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

def __init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

get​

def get(self) -> MMASmemDescriptorPair

Get smem descriptor for current stage.

Returns:

MMASmemDescriptorPair

wait​

def wait[*, qk_stage: Int = 0](self)

Wait for tile from producer.

release​

def release[*, qk_stage: Int = 0](mut self, e: Int32)

Release buffer after consuming.
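The wait/get/release cycle can be pictured with a toy ring buffer. The Python sketch below is a single-threaded model under stated assumptions, not the `StagedPipeline` implementation: `ToyConsumer`, `produce`, and the `full` flags are hypothetical, a real barrier would block rather than raise, and the `qk_stage` and `e` parameters are omitted because this page does not describe them.

```python
class ToyConsumer:
    """Single-threaded model of a consumer cycling through buffer stages."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.full = [False] * num_stages  # True once a stage holds a tile
        self.stage = 0                    # stage the consumer reads next

    def produce(self, stage):
        # Stand-in for the producer signalling that a tile has arrived.
        assert not self.full[stage], "stage was not released yet"
        self.full[stage] = True

    def wait(self):
        # A real barrier would block here; the model only checks the flag.
        if not self.full[self.stage]:
            raise RuntimeError("tile not ready")
        return self.stage

    def release(self):
        # Hand the buffer back to the producer and advance to the next stage.
        self.full[self.stage] = False
        self.stage = (self.stage + 1) % self.num_stages


if __name__ == "__main__":
    c = ToyConsumer(num_stages=2)
    c.produce(0)
    c.produce(1)
    for _ in range(2):
        print("consuming stage", c.wait())
        c.release()
```

The point of the model is the ordering: a stage is read only after `wait` succeeds, and the producer may refill it only after `release`.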

get_k​

def get_k(self) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

wait_k​

def wait_k[*, qk_stage: Int = (config - 1)](mut self)

Wait on K stage from the producer.

release_k​

def release_k[*, qk_stage: Int = (config - 1)](mut self, e: Int32)

Release K buffer after consuming this stage.

get_v​

def get_v(self) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

wait_v​

def wait_v(self)

Wait for the V tile from the producer.

release_v​

def release_v(mut self, e: Int32)

Release V buffer after consuming.