Mojo struct
TMAProducerPipeline
@register_passable(trivial)
struct TMAProducerPipeline[dtype: DType, config: FA4Config, is_k: Bool = True]
Unified producer pipeline for K and V TMA loading.
K loading (is_k=True): can be staged into num_qk_stages chunks and uses the k_major tile layout.
V loading (is_k=False): always loaded complete (qk_stage=0) and uses the mn_major tile layout.
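The K/V split above can be modeled with a small parameterized function. This is a hypothetical helper, not part of the library: it mirrors the documented num_qk_stages_effective expression, with num_qk_stages passed as a plain Int instead of being read from an FA4Config.

```mojo
fn num_qk_stages_effective[is_k: Bool](num_qk_stages: Int) -> Int:
    # Hypothetical helper mirroring the documented comptime member:
    # K may be split into num_qk_stages chunks; V is always one complete load.
    return num_qk_stages if is_k else 1
```

For example, with num_qk_stages = 2 the pipeline field is a StagedPipeline[config.num_kv_stages, 2] for K and a StagedPipeline[config.num_kv_stages, 1] for V.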
Fields
- pipeline (StagedPipeline[config.num_kv_stages, TMAProducerPipeline[dtype, config, is_k].num_qk_stages_effective]):
- smem (TMAProducerPipeline[dtype, config, is_k].SMemType):
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
bytes
comptime bytes = TMAProducerPipeline[dtype, config, is_k].tile_bytes
elements
comptime elements = TMAProducerPipeline[dtype, config, is_k].tile_layout.size()
elements_full
comptime elements_full = (TMAProducerPipeline[dtype, config, is_k].elements * config) if is_k else TMAProducerPipeline[dtype, config, is_k].elements
KPairType
comptime KPairType = TMAProducerPipeline[dtype, config, is_k].PairType
num_qk_stages_effective
comptime num_qk_stages_effective = config.num_qk_stages if is_k else 1
PairType
comptime PairType = TMADestination[dtype, TMAProducerPipeline[dtype, config, is_k].tile_layout]
SMemType
comptime SMemType = UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
tile_bytes
comptime tile_bytes = (TMAProducerPipeline[dtype, config, is_k].elements * size_of[dtype]())
tile_layout
comptime tile_layout = tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]() if is_k else tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode]()
TileType
comptime TileType = LayoutTensor[dtype, TMAProducerPipeline[dtype, config, is_k].tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]
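The byte-count members chain together: tile_bytes is elements times the element size, and elements is the size of the selected tile layout. The sketch below reproduces that arithmetic with hypothetical helper names. It assumes each layout's size is the product of its two extents (BN x BK0 for K, padded_depth x BK1 for V), and it takes the element size in bytes as an argument instead of calling size_of[dtype]().

```mojo
fn tile_elements[is_k: Bool](bn: Int, bk0: Int, padded_depth: Int, bk1: Int) -> Int:
    # Assumption: the layout size equals the product of the tile extents.
    # K uses a BN x BK0 k-major tile; V uses a padded_depth x BK1 mn-major tile.
    return bn * bk0 if is_k else padded_depth * bk1


fn tile_bytes[is_k: Bool](
    bn: Int, bk0: Int, padded_depth: Int, bk1: Int, elem_bytes: Int
) -> Int:
    # Mirrors tile_bytes = elements * size_of[dtype](), with elem_bytes
    # standing in for size_of[dtype]().
    return tile_elements[is_k](bn, bk0, padded_depth, bk1) * elem_bytes
```

With illustrative values BN = 128, BK0 = 64 and a 2-byte dtype such as bfloat16, one K tile holds 8192 elements, or 16384 bytes.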
Methods
__init__
__init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
__init__(pipeline: StagedPipeline[config.num_kv_stages, TMAProducerPipeline[dtype, config, is_k].num_qk_stages_effective], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
get_smem
get_smem[*, qk_stage: Int = 0](self) -> TMAProducerPipeline[dtype, config, is_k].SMemType
Gets the shared-memory pointer for the current stage.
Returns:
TMAProducerPipeline[dtype, config, is_k].SMemType
get_tile
get_tile[*, qk_stage: Int = 0](self) -> TMAProducerPipeline[dtype, config, is_k].PairType
Gets the TMA destination for the current stage.
Returns:
TMAProducerPipeline[dtype, config, is_k].PairType
get_tile[*, qk_stage: Int = 0](self, e: Int32) -> TMAProducerPipeline[dtype, config, is_k].PairType
Gets the TMA destination for the current stage, also passing e as the expect_bytes value.
Returns:
TMAProducerPipeline[dtype, config, is_k].PairType
acquire
acquire[*, qk_stage: Int = 0](self)
Waits for the consumer to release the buffer for the current stage.
commit_step
commit_step(mut self)
Advances the pipeline to the next stage. No separate commit is issued here; the commit is handled by tma_op.async_copy.
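acquire and commit_step follow the usual multi-buffered producer pattern: wait until the consumer has released a buffer, issue the TMA load into it, then advance to the next buffer. The sketch below models only the advance, as a ring of num_stages slots with a parity bit that flips each time the index wraps, which is the common convention for mbarrier phase tracking. The helper names are hypothetical; this is not the StagedPipeline implementation.

```mojo
fn next_index(index: Int, num_stages: Int) -> Int:
    # Advance to the next buffer slot, wrapping around the ring.
    return (index + 1) % num_stages


fn next_phase(index: Int, phase: Int, num_stages: Int) -> Int:
    # Flip the parity bit only when this step wraps back to slot 0.
    return phase ^ 1 if index + 1 == num_stages else phase
```

Walking a 2-stage ring from (index 0, phase 0) visits (1, 0), then (0, 1), then (1, 1), so a wait on a reused slot can distinguish the current release from the previous one by its parity.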
get_k_smem
get_k_smem[*, qk_stage: Int](self) -> TMAProducerPipeline[dtype, config, is_k].SMemType
Returns:
TMAProducerPipeline[dtype, config, is_k].SMemType
get_k
get_k[*, qk_stage: Int](self) -> TMAProducerPipeline[dtype, config, is_k].PairType
Returns:
TMAProducerPipeline[dtype, config, is_k].PairType
get_k[*, qk_stage: Int](self, e: Int32) -> TMAProducerPipeline[dtype, config, is_k].PairType
Returns:
TMAProducerPipeline[dtype, config, is_k].PairType
acquire_k
acquire_k[*, qk_stage: Int](self)
get_v_smem
get_v_smem(self) -> TMAProducerPipeline[dtype, config, is_k].SMemType
Returns:
TMAProducerPipeline[dtype, config, is_k].SMemType
get_v
get_v(self, e: Int32) -> TMAProducerPipeline[dtype, config, is_k].PairType
Returns:
TMAProducerPipeline[dtype, config, is_k].PairType
acquire_v
acquire_v(self)