Mojo struct
TMAConsumerPipeline
struct TMAConsumerPipeline[dtype: DType, config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype], is_k: Bool = True]
Unified consumer pipeline for K and V TMA consumption.
- K consumption (is_k=True) uses the k_major layout and supports staged consumption across qk_stages.
- V consumption (is_k=False) uses the mn_major layout and always uses qk_stage=0.
This follows the operation order of the Tri Dao and CUTLASS implementations, up to a rotation of the operations across iterations.
We consume/produce in the following order:
0. S0 <- Q0 @ Kn'
1. O1 <- O1 + P1 @ V{n-1}
2. S1 <- Q1 @ Kn'
3. O0 <- O0 + P0 @ Vn
Note that two MMAs are issued between computing Si and consuming Pi, which maximizes the overlap between the MMAs and the softmax calculation.
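To make the interleaving concrete, here is a small host-side Python model (not Mojo, and not part of the library) that lists the MMAs in the order above and counts how many MMAs are issued between computing S_i and consuming P_i. The function names, the tuple encoding, and the handling of the first and last iterations (skipping the O1 update on the first iteration and draining it once at the end) are assumptions made for illustration.

```python
# Python model (not Mojo) of the interleaved MMA order; all names are illustrative.

def mma_schedule(num_iters):
    """List MMAs in issue order for num_iters KV tiles.

    Each entry is (op, tile, n): op "S" means S_tile <- Q_tile @ K_n^T,
    op "O" means O_tile <- O_tile + P_tile @ V_n.
    Assumption: O1 is skipped on the first iteration (no P1 exists yet)
    and drained once after the loop.
    """
    ops = []
    for n in range(num_iters):
        ops.append(("S", 0, n))          # 0. S0 <- Q0 @ Kn'
        if n > 0:
            ops.append(("O", 1, n - 1))  # 1. O1 <- O1 + P1 @ V{n-1}
        ops.append(("S", 1, n))          # 2. S1 <- Q1 @ Kn'
        ops.append(("O", 0, n))          # 3. O0 <- O0 + P0 @ Vn
    ops.append(("O", 1, num_iters - 1))  # assumed drain of the last P1 @ V
    return ops


def mmas_between_s_and_p(ops, tile, n):
    """Count MMAs issued after S_tile for KV tile n and before P_tile is consumed."""
    s_idx = ops.index(("S", tile, n))
    p_idx = ops.index(("O", tile, n))
    return p_idx - s_idx - 1


ops = mma_schedule(3)
print([mmas_between_s_and_p(ops, t, n) for n in range(3) for t in (0, 1)])
# [1, 2, 2, 2, 2, 1]
```

In the steady state the gap is 2 for both softmax tiles. In this model only the very first S0 and the very last S1 see a single intervening MMA, which comes from the assumed prologue and drain rather than from the ordering itself.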
Fields
- pipeline (StagedPipeline[config.num_kv_stages, config.num_qk_stages if is_k else Int(1)])
- smem_desc (MMASmemDescriptorPair)
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable
comptime members
BK
comptime BK = config.BK0 if is_k else config.BK1
BMN
comptime BMN = config.k_rows_per_cta() if is_k else config.v_cols_per_cta()
full_kv_bytes
comptime full_kv_bytes = (Int(config.k_rows_per_cta() * size_of[dtype]() * config.padded_ov_depth) + Int(config.k_rows_per_cta() * config.rope_depth() * size_of[rope_dtype]())) if is_k else (Int(config.v_cols_per_cta() * config.BN) * size_of[dtype]())
is_k_major
comptime is_k_major = is_k
num_qk_stages_effective
comptime num_qk_stages_effective = config.num_qk_stages if is_k else Int(1)
staged_k_bytes
comptime staged_k_bytes = Int(config.k_rows_per_cta() * config.BK0) * size_of[dtype]()
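The comptime byte counts can be checked by hand. The Python sketch below mirrors the formulas listed above. HypotheticalFA4Config and its numbers are made up for illustration. In the real config, k_rows_per_cta, v_cols_per_cta and rope_depth are methods rather than attributes. Dtype sizes are passed in as byte counts, for example 2 for a 16-bit type.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HypotheticalFA4Config:
    # Illustrative values only; not taken from any real FA4Config.
    k_rows_per_cta: int = 128
    v_cols_per_cta: int = 128
    padded_ov_depth: int = 128
    rope_depth: int = 64
    BK0: int = 64
    BK1: int = 32
    BN: int = 128
    num_qk_stages: int = 2


def comptime_members(cfg, is_k, dtype_bytes, rope_dtype_bytes):
    """Evaluate the TMAConsumerPipeline comptime formulas in plain Python."""
    if is_k:
        # K tile: non-RoPE columns plus RoPE columns (which may use another dtype).
        full_kv_bytes = (cfg.k_rows_per_cta * dtype_bytes * cfg.padded_ov_depth
                         + cfg.k_rows_per_cta * cfg.rope_depth * rope_dtype_bytes)
    else:
        # V tile: v_cols_per_cta x BN elements.
        full_kv_bytes = cfg.v_cols_per_cta * cfg.BN * dtype_bytes
    return {
        "BK": cfg.BK0 if is_k else cfg.BK1,
        "BMN": cfg.k_rows_per_cta if is_k else cfg.v_cols_per_cta,
        "full_kv_bytes": full_kv_bytes,
        "is_k_major": is_k,
        "num_qk_stages_effective": cfg.num_qk_stages if is_k else 1,
        "staged_k_bytes": cfg.k_rows_per_cta * cfg.BK0 * dtype_bytes,
    }


cfg = HypotheticalFA4Config()
k = comptime_members(cfg, is_k=True, dtype_bytes=2, rope_dtype_bytes=2)
v = comptime_members(cfg, is_k=False, dtype_bytes=2, rope_dtype_bytes=2)
print(k["full_kv_bytes"], k["staged_k_bytes"], v["full_kv_bytes"])  # 49152 16384 32768
```

With these made-up numbers, the two 64-column staged_k_bytes chunks (2 * 16384 = 32768 bytes) together cover the 128-wide non-RoPE part of the K tile. The remaining 16384 bytes of the K full_kv_bytes are the RoPE columns.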
Methods
__init__
def __init__(pipeline: StagedPipeline[config.num_kv_stages, config.num_qk_stages if is_k else Int(1)], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
def __init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
get
def get(self) -> MMASmemDescriptorPair
Get smem descriptor for current stage.
Returns:
The MMASmemDescriptorPair for the current stage.
wait
def wait[*, qk_stage: Int = Int(0)](self)
Wait for tile from producer.
release
def release[*, qk_stage: Int = Int(0)](mut self, e: Int32)
Release buffer after consuming.
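The wait / get / release methods follow the consumer side of a multi-stage TMA ring buffer. The Python class below is a single-threaded model of that protocol under several assumptions. Each slot is tracked by a (stage, phase) pair, with the phase flipping whenever the ring wraps. get returns a byte offset in place of the real MMASmemDescriptorPair. The e: Int32 argument of the real release is not modeled. RingConsumerModel and its methods are hypothetical names, not part of the Mojo API.

```python
class RingConsumerModel:
    """Single-threaded model of a multi-stage consumer ring (hypothetical).

    The real pipeline blocks on shared-memory barriers; here a set records
    which (stage, phase) slots the producer has filled.
    """

    def __init__(self, num_stages, stage_bytes, smem_base=0):
        self.num_stages = num_stages
        self.stage_bytes = stage_bytes
        self.smem_base = smem_base
        self.index = 0       # current stage
        self.phase = 0       # flips each time the ring wraps around
        self.full = set()    # (stage, phase) slots published by the producer
        self.released = []   # stages handed back to the producer, in order

    def produce(self, stage, phase):
        # Stand-in for the producer's TMA load completing on this slot.
        self.full.add((stage, phase))

    def wait(self):
        # Stand-in for blocking on the current stage's "full" barrier.
        if (self.index, self.phase) not in self.full:
            raise RuntimeError("consumer would block: stage not yet filled")

    def get(self):
        # Stand-in for the smem descriptor: byte offset of the current stage.
        return self.smem_base + self.index * self.stage_bytes

    def release(self):
        # Hand the slot back to the producer and advance to the next stage.
        self.full.discard((self.index, self.phase))
        self.released.append(self.index)
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


ring = RingConsumerModel(num_stages=2, stage_bytes=16384)
ring.produce(0, 0)
ring.produce(1, 0)
offsets = []
for _ in range(2):
    ring.wait()
    offsets.append(ring.get())
    ring.release()
print(offsets, ring.index, ring.phase)  # [0, 16384] 0 1
```

After one full pass the model is back at stage 0 with the phase flipped, so a further wait succeeds only once the producer publishes (0, 1). This is how the phase bit distinguishes a refilled slot from a stale one.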
get_k
wait_k
def wait_k[*, qk_stage: Int = Int(config.num_qk_stages - 1)](mut self)
Wait for the given K stage (qk_stage) from the producer.
release_k
def release_k[*, qk_stage: Int = Int(config.num_qk_stages - 1)](mut self, e: Int32)
Release K buffer after consuming this stage.
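Staged K consumption splits the head dimension of the K tile into num_qk_stages chunks of BK0 columns, presumably so an MMA can start on an early chunk while later chunks are still arriving. The Python sketch below (plain lists, hypothetical function names) checks the arithmetic that makes this valid: accumulating partial Q @ K^T products over the chunks reproduces the full product. The points where wait_k and release_k would be called in the kernel appear only in comments and are an assumption.

```python
def partial_qk_t(q, k, cols):
    """Partial S = Q[:, cols] @ K[:, cols]^T over the given head-dim columns."""
    return [[sum(q[i][c] * k[j][c] for c in cols) for j in range(len(k))]
            for i in range(len(q))]


def staged_qk_t(q, k, num_qk_stages):
    """Accumulate S = Q @ K^T one head-dim chunk (one qk stage) at a time."""
    depth = len(q[0])
    assert depth % num_qk_stages == 0
    bk = depth // num_qk_stages  # plays the role of BK0
    s = [[0] * len(k) for _ in q]
    for stage in range(num_qk_stages):
        # Assumption: wait_k[qk_stage=stage]() would gate this chunk's MMA.
        cols = range(stage * bk, (stage + 1) * bk)
        part = partial_qk_t(q, k, cols)
        for i in range(len(q)):
            for j in range(len(k)):
                s[i][j] += part[i][j]
    # Assumption: release_k would follow here; its default qk_stage
    # (num_qk_stages - 1) points at the last stage.
    return s


q = [[1, 2, 3, 4], [5, 6, 7, 8]]
k = [[1, 0, 1, 0], [0, 1, 0, 1], [2, 2, 2, 2]]
print(staged_qk_t(q, k, 2))  # [[4, 6, 20], [12, 14, 52]]
```

Because the head dimension is the reduction axis of Q @ K^T, the chunks can be consumed in stage order and summed without changing the result. This sketch shows only the math; the actual barrier and release placement is defined by the Mojo implementation.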
get_v
wait_v
def wait_v(self)
Wait for V tile.
release_v
def release_v(mut self, e: Int32)
Release V buffer after consuming.