For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
TMAProducerPipeline
struct TMAProducerPipeline[dtype: DType, config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype], is_k: Bool = True]
Unified producer pipeline for K and V TMA loading.
K loading (is_k=True): can be staged in num_qk_stages chunks and uses the k-major tile layout. V loading (is_k=False): always loaded as a single complete tile (qk_stage=0) and uses the mn-major tile layout.
Fields
- pipeline: StagedPipeline[config.num_kv_stages, config.num_qk_stages if is_k else Int(1)]
- smem: TMAProducerPipeline[dtype, config, is_k].SMemType
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
bytes
comptime bytes = TMAProducerPipeline[dtype, config, is_k].tile_bytes
elements
comptime elements = TMAProducerPipeline[dtype, config, is_k].tile_layout.size()
elements_full
comptime elements_full = ((tile_layout_k_major[dtype, config.k_rows_per_cta(), config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.v_cols_per_cta(), config.BK1, config.swizzle_mode]()).size() * config.num_qk_stages) if is_k else TMAProducerPipeline[dtype, config, is_k].elements
KPairType
comptime KPairType = TMAProducerPipeline[dtype, config, is_k].PairType
num_qk_stages_effective
comptime num_qk_stages_effective = config.num_qk_stages if is_k else Int(1)
PairType
comptime PairType = TMADestination[dtype, (tile_layout_k_major[dtype, config.k_rows_per_cta(), config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.v_cols_per_cta(), config.BK1, config.swizzle_mode]()).size()]
SMemType
comptime SMemType = UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
tile_bytes
comptime tile_bytes = (tile_layout_k_major[dtype, config.k_rows_per_cta(), config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.v_cols_per_cta(), config.BK1, config.swizzle_mode]()).size() * size_of[dtype]()
tile_layout
comptime tile_layout = tile_layout_k_major[dtype, config.k_rows_per_cta(), config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.v_cols_per_cta(), config.BK1, config.swizzle_mode]()
Methods
__init__
def __init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
def __init__(pipeline: StagedPipeline[config.num_kv_stages, config.num_qk_stages if is_k else Int(1)], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
get_smem
def get_smem[*, qk_stage: Int = Int(0)](self) -> Self.SMemType
Get the shared-memory pointer for the current stage.
Returns:
Self.SMemType
get_tile
def get_tile[*, qk_stage: Int = Int(0)](self) -> Self.PairType
Get the TMA destination for the current stage.
Returns:
Self.PairType
def get_tile[*, qk_stage: Int = Int(0)](self, e: Int32) -> Self.PairType
Get the TMA destination for the current stage; this overload additionally takes an expect_bytes value `e`.
Returns:
Self.PairType
acquire
def acquire[*, qk_stage: Int = Int(0)](self)
Wait for consumer to release the buffer.
commit_step
def commit_step(mut self)
Step the pipeline. The commit itself is handled by tma_op.async_copy.
get_k_smem
def get_k_smem[*, qk_stage: Int](self) -> Self.SMemType
Returns:
Self.SMemType
get_k
def get_k[*, qk_stage: Int](self) -> Self.PairType
Returns:
Self.PairType
def get_k[*, qk_stage: Int](self, e: Int32) -> Self.PairType
Returns:
Self.PairType
acquire_k
def acquire_k[*, qk_stage: Int](self)
get_v_smem
def get_v_smem(self) -> Self.SMemType
Returns:
Self.SMemType
get_v
def get_v(self, e: Int32) -> Self.PairType
Returns:
Self.PairType
acquire_v
def acquire_v(self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!