IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

TMAProducerPipeline

struct TMAProducerPipeline[dtype: DType, config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype], is_k: Bool = True]

Unified producer pipeline for K and V TMA loading.

K loading (is_k=True): can be staged, i.e. split into num_qk_stages chunks, and uses the k-major tile layout. V loading (is_k=False): always loaded as one complete tile (qk_stage=0) and uses the mn-major tile layout.

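Fields

  • pipeline (StagedPipeline[config.num_kv_stages, config.num_qk_stages if is_k else Int(1)]): Staged pipeline state, parameterized by config.num_kv_stages and by num_qk_stages_effective (config.num_qk_stages for K, 1 for V).
  • smem (TMAProducerPipeline[dtype, config, is_k].SMemType): Pointer to the tile storage in shared memory (AddressSpace.SHARED).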

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

bytes​

comptime bytes = TMAProducerPipeline[dtype, config, is_k].tile_bytes

elements​

comptime elements = TMAProducerPipeline[dtype, config, is_k].tile_layout.size()

elements_full​

comptime elements_full = ((tile_layout_k_major[dtype, config.k_rows_per_cta(), config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.v_cols_per_cta(), config.BK1, config.swizzle_mode]()).size() * config.num_qk_stages) if is_k else TMAProducerPipeline[dtype, config, is_k].elements

KPairType​

comptime KPairType = TMAProducerPipeline[dtype, config, is_k].PairType

num_qk_stages_effective​

comptime num_qk_stages_effective = config.num_qk_stages if is_k else Int(1)

PairType​

comptime PairType = TMADestination[dtype, (tile_layout_k_major[dtype, config.k_rows_per_cta(), config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.v_cols_per_cta(), config.BK1, config.swizzle_mode]()).size()]

SMemType​

comptime SMemType = UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

tile_bytes​

comptime tile_bytes = (tile_layout_k_major[dtype, config.k_rows_per_cta(), config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.v_cols_per_cta(), config.BK1, config.swizzle_mode]()).size() * size_of[dtype]()

tile_layout​

comptime tile_layout = tile_layout_k_major[dtype, config.k_rows_per_cta(), config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.v_cols_per_cta(), config.BK1, config.swizzle_mode]()

Methods

__init__​

def __init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

def __init__(pipeline: StagedPipeline[config.num_kv_stages, config.num_qk_stages if is_k else Int(1)], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

get_smem​

def get_smem[*, qk_stage: Int = Int(0)](self) -> Self.SMemType

Get the shared-memory pointer for the current stage.

Returns:

Self.SMemType

get_tile​

def get_tile[*, qk_stage: Int = Int(0)](self) -> Self.PairType

Get the TMA destination for the current stage.

Returns:

Self.PairType

def get_tile[*, qk_stage: Int = Int(0)](self, e: Int32) -> Self.PairType

Get the TMA destination for the current stage, additionally passing e as the expect_bytes value.

Returns:

Self.PairType

acquire​

def acquire[*, qk_stage: Int = Int(0)](self)

Wait for the consumer to release the buffer.

commit_step​

def commit_step(mut self)

Advance the pipeline by one step. The commit itself is handled by tma_op.async_copy.

get_k_smem​

def get_k_smem[*, qk_stage: Int](self) -> Self.SMemType

Returns:

Self.SMemType

get_k​

def get_k[*, qk_stage: Int](self) -> Self.PairType

Returns:

Self.PairType

def get_k[*, qk_stage: Int](self, e: Int32) -> Self.PairType

Returns:

Self.PairType

acquire_k​

def acquire_k[*, qk_stage: Int](self)

get_v_smem​

def get_v_smem(self) -> Self.SMemType

Returns:

Self.SMemType

get_v​

def get_v(self, e: Int32) -> Self.PairType

Returns:

Self.PairType

acquire_v​

def acquire_v(self)