Skip to main content

Mojo struct

TMAProducerPipeline

@register_passable(trivial) struct TMAProducerPipeline[dtype: DType, config: FA4Config, is_k: Bool = True]

Unified producer pipeline for K and V TMA loading.

  • K loading (is_k=True): can be staged (split into num_qk_stages chunks); uses the k_major tile layout.
  • V loading (is_k=False): always loaded as a complete tile (qk_stage=0); uses the mn_major tile layout.

Fields

  • pipeline (StagedPipeline[config.num_kv_stages, TMAProducerPipeline[dtype, config, is_k].num_qk_stages_effective])
  • smem (TMAProducerPipeline[dtype, config, is_k].SMemType)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

bytes

comptime bytes = TMAProducerPipeline[dtype, config, is_k].tile_bytes

elements

comptime elements = TMAProducerPipeline[dtype, config, is_k].tile_layout.size()

elements_full

comptime elements_full = (TMAProducerPipeline[dtype, config, is_k].elements * config.num_qk_stages) if is_k else TMAProducerPipeline[dtype, config, is_k].elements

KPairType

comptime KPairType = TMAProducerPipeline[dtype, config, is_k].PairType

num_qk_stages_effective

comptime num_qk_stages_effective = config.num_qk_stages if is_k else 1

PairType

comptime PairType = TMADestination[dtype, TMAProducerPipeline[dtype, config, is_k].tile_layout]

SMemType

comptime SMemType = UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

tile_bytes

comptime tile_bytes = (TMAProducerPipeline[dtype, config, is_k].elements * size_of[dtype]())

tile_layout

comptime tile_layout = tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode]()

TileType

comptime TileType = LayoutTensor[dtype, TMAProducerPipeline[dtype, config, is_k].tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]

Methods

__init__

__init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

__init__(pipeline: StagedPipeline[config.num_kv_stages, TMAProducerPipeline[dtype, config, is_k].num_qk_stages_effective], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

get_smem

get_smem[*, qk_stage: Int = 0](self) -> TMAProducerPipeline[dtype, config, is_k].SMemType

Get the shared-memory pointer for the current stage.

Returns:

TMAProducerPipeline[dtype, config, is_k].SMemType

get_tile

get_tile[*, qk_stage: Int = 0](self) -> TMAProducerPipeline[dtype, config, is_k].PairType

Get TMA destination for this stage.

Returns:

TMAProducerPipeline[dtype, config, is_k].PairType

get_tile[*, qk_stage: Int = 0](self, e: Int32) -> TMAProducerPipeline[dtype, config, is_k].PairType

Get TMA destination with optional expect_bytes.

Returns:

TMAProducerPipeline[dtype, config, is_k].PairType

acquire

acquire[*, qk_stage: Int = 0](self)

Wait for consumer to release the buffer.

commit_step

commit_step(mut self)

Step the pipeline. Commit is handled by tma_op.async_copy.

get_k_smem

get_k_smem[*, qk_stage: Int](self) -> TMAProducerPipeline[dtype, config, is_k].SMemType

Returns:

TMAProducerPipeline[dtype, config, is_k].SMemType

get_k

get_k[*, qk_stage: Int](self) -> TMAProducerPipeline[dtype, config, is_k].PairType

Returns:

TMAProducerPipeline[dtype, config, is_k].PairType

get_k[*, qk_stage: Int](self, e: Int32) -> TMAProducerPipeline[dtype, config, is_k].PairType

Returns:

TMAProducerPipeline[dtype, config, is_k].PairType

acquire_k

acquire_k[*, qk_stage: Int](self)

get_v_smem

get_v_smem(self) -> TMAProducerPipeline[dtype, config, is_k].SMemType

Returns:

TMAProducerPipeline[dtype, config, is_k].SMemType

get_v

get_v(self, e: Int32) -> TMAProducerPipeline[dtype, config, is_k].PairType

Returns:

TMAProducerPipeline[dtype, config, is_k].PairType

acquire_v

acquire_v(self)

Was this page helpful?