Mojo struct

TMAConsumerPipeline

struct TMAConsumerPipeline[dtype: DType, config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype], is_k: Bool = True]

Unified consumer pipeline for K and V TMA consumption.

  • K consumption (is_k=True): uses the k_major layout and supports staged qk_stages.
  • V consumption (is_k=False): uses the mn_major layout and always uses qk_stage=0.

This follows the operation order used in the Tri Dao and CUTLASS implementations (modulo any rotation of the ops through the iterations).

We consume/produce in the following order:

  0. S0 <- Q0 @ Kn'
  1. O1 <- O1 + P1 @ V{n-1}
  2. S1 <- Q1 @ Kn'
  3. O0 <- O0 + P0 @ Vn

Note that two MMAs sit between computing Si and consuming Pi, which maximizes the overlap between the MMAs and the softmax calculation.
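The spacing claim above can be checked with a small model. The following is a plain Python sketch, not Mojo and not part of this API: it repeats the four-step order and counts how many MMAs fall between each Si computation and the next use of Pi. The names STEPS and mmas_between_score_and_use are hypothetical, introduced only for this illustration.

```python
# One steady-state iteration, in the order listed above.
# ("S", i) computes scores Si; ("O", i) consumes Pi to update Oi.
STEPS = [("S", 0), ("O", 1), ("S", 1), ("O", 0)]


def mmas_between_score_and_use(i, iterations=4):
    """Count the MMAs issued between each Si and the next Oi update."""
    seq = STEPS * iterations
    gaps = []
    for pos, step in enumerate(seq):
        if step != ("S", i):
            continue
        # Look ahead for the first MMA that consumes Pi.
        for later in range(pos + 1, len(seq)):
            if seq[later] == ("O", i):
                gaps.append(later - pos - 1)
                break
    return gaps


if __name__ == "__main__":
    print(mmas_between_score_and_use(0))
    print(mmas_between_score_and_use(1))
```

For both query tiles the gap is two MMAs. For tile 1 the consuming MMA belongs to the next iteration, which is why that step reads V{n-1} rather than Vn.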

Fields​

  • ​pipeline (StagedPipeline[config.num_kv_stages, config.num_qk_stages if is_k else Int(1)]):
  • ​smem_desc (MMASmemDescriptorPair):

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

BK​

comptime BK = config.BK0 if is_k else config.BK1

BMN​

comptime BMN = config.k_rows_per_cta() if is_k else config.v_cols_per_cta()

full_kv_bytes​

comptime full_kv_bytes = (Int((mul config.k_rows_per_cta(), size_of[dtype](), config.padded_ov_depth)) + Int((mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]()))) if is_k else (Int((mul config.v_cols_per_cta(), config.BN)) * size_of[dtype]())
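
As a worked example of the full_kv_bytes expression above, the Python sketch below evaluates both branches. It is a restatement of the arithmetic only. Every numeric value is hypothetical and chosen for illustration, not taken from any real FA4Config, and the keyword names simply mirror the config members that appear in the expression.

```python
def full_kv_bytes(is_k, *, k_rows_per_cta, v_cols_per_cta, padded_ov_depth,
                  rope_depth, bn, dtype_size, rope_dtype_size):
    """Mirror of the comptime expression, with sizes given in bytes."""
    if is_k:
        # K tile: main depth columns plus the rope columns.
        main = k_rows_per_cta * dtype_size * padded_ov_depth
        rope = k_rows_per_cta * rope_depth * rope_dtype_size
        return main + rope
    # V tile: v_cols_per_cta x BN elements of dtype.
    return v_cols_per_cta * bn * dtype_size


# Hypothetical configuration (assumed values, 2-byte element types).
cfg = dict(k_rows_per_cta=128, v_cols_per_cta=128, padded_ov_depth=128,
           rope_depth=64, bn=128, dtype_size=2, rope_dtype_size=2)

k_bytes = full_kv_bytes(True, **cfg)   # 128*2*128 + 128*64*2
v_bytes = full_kv_bytes(False, **cfg)  # 128*128*2
```

With these assumed values the K tile occupies 49152 bytes and the V tile 32768 bytes.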

is_k_major​

comptime is_k_major = is_k

num_qk_stages_effective​

comptime num_qk_stages_effective = config.num_qk_stages if is_k else Int(1)

staged_k_bytes​

comptime staged_k_bytes = (Int((mul config.k_rows_per_cta(), config.BK0)) * size_of[dtype]())
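
The staged_k_bytes expression above is the size of a single K stage: k_rows_per_cta rows by BK0 columns of dtype. A minimal Python sketch of that arithmetic follows. The values are hypothetical, and the assumption that BK0 times the number of stages covers the whole depth is made for this example only and is not stated by the source.

```python
def staged_k_bytes(k_rows_per_cta, bk0, dtype_size):
    """Bytes in one K stage: k_rows_per_cta x BK0 elements."""
    return k_rows_per_cta * bk0 * dtype_size


# Hypothetical values: 128 rows, BK0 = 64, 2-byte elements.
one_stage = staged_k_bytes(k_rows_per_cta=128, bk0=64, dtype_size=2)

# Assumption for illustration: depth 128 split into stages of BK0 = 64,
# so two stages together hold the non-rope part of a K tile.
depth = 128
stages_needed = depth // 64
total = stages_needed * one_stage
```

Under these assumptions one stage is 16384 bytes and two stages total 32768 bytes.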

Methods​

__init__​

def __init__(pipeline: StagedPipeline[config.num_kv_stages, config.num_qk_stages if is_k else Int(1)], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

def __init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

get​

def get(self) -> MMASmemDescriptorPair

Get smem descriptor for current stage.

Returns:

MMASmemDescriptorPair

wait​

def wait[*, qk_stage: Int = Int(0)](self)

Wait for tile from producer.

release​

def release[*, qk_stage: Int = Int(0)](mut self, e: Int32)

Release buffer after consuming.
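
The wait, get, and release methods above read as a conventional producer/consumer handshake over a ring of buffers. The sketch below is a CPU-side Python analogy using semaphores. It is not how the Mojo struct is implemented, it does not model the e argument of release, and ToyPipeline, producer, consume_all, and run are hypothetical names used only here.

```python
import threading


class ToyPipeline:
    """Ring of num_stages buffers guarded by full/empty semaphores."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.buffers = [None] * num_stages
        self.full = [threading.Semaphore(0) for _ in range(num_stages)]
        self.empty = [threading.Semaphore(1) for _ in range(num_stages)]


def producer(pipe, tiles):
    # Stand-in for the TMA producer: fill each stage once it is free.
    for i, tile in enumerate(tiles):
        s = i % pipe.num_stages
        pipe.empty[s].acquire()
        pipe.buffers[s] = tile
        pipe.full[s].release()


def consume_all(pipe, num_tiles):
    out = []
    for i in range(num_tiles):
        s = i % pipe.num_stages
        pipe.full[s].acquire()       # analogous to wait()
        out.append(pipe.buffers[s])  # analogous to get() followed by the MMA
        pipe.empty[s].release()      # analogous to release()
    return out


def run(num_stages=2, num_tiles=5):
    pipe = ToyPipeline(num_stages)
    tiles = [f"K{i}" for i in range(num_tiles)]
    t = threading.Thread(target=producer, args=(pipe, tiles))
    t.start()
    result = consume_all(pipe, num_tiles)
    t.join()
    return result
```

The point of the analogy is the ordering: a buffer is read only after the wait succeeds, and the producer may overwrite it only after the release.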

get_k​

def get_k(self) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

wait_k​

def wait_k[*, qk_stage: Int = Int((add config.num_qk_stages, -1))](mut self)

Wait on K stage from the producer.

release_k​

def release_k[*, qk_stage: Int = Int((add config.num_qk_stages, -1))](mut self, e: Int32)

Release K buffer after consuming this stage.
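
One way to picture staged K consumption is as a blocked accumulation of S = Q @ K' over chunks of the depth dimension. The Python sketch below assumes each qk_stage covers bk consecutive depth columns; the source does not state this outright, so treat it as an assumption. The function staged_scores is hypothetical, and the comments only mark where wait_k and release_k calls would plausibly sit.

```python
def staged_scores(q, k, bk):
    """Accumulate q @ k' in depth chunks of width bk.

    q is a list of m rows of length d; k is a list of n rows of length d.
    """
    m, n, d = len(q), len(k), len(q[0])
    s = [[0] * n for _ in range(m)]
    for start in range(0, d, bk):  # one pass per assumed qk_stage
        stop = min(start + bk, d)
        # (a wait on this K stage would go here)
        for i in range(m):
            for j in range(n):
                s[i][j] += sum(q[i][c] * k[j][c] for c in range(start, stop))
        # (the release of this K stage would go here)
    return s


q = [[1, 2, 3, 4]]
k = [[1, 0, 1, 0], [2, 2, 2, 2]]
two_stage = staged_scores(q, k, bk=2)
one_stage = staged_scores(q, k, bk=4)
```

Splitting the depth into stages does not change the result; it only lets the first partial MMA start before the whole K tile has arrived.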

get_v​

def get_v(self) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

wait_v​

def wait_v(self)

Wait for V tile.

release_v​

def release_v(mut self, e: Int32)

Release V buffer after consuming.