Mojo struct
StagedPipeline
@register_passable(trivial)
struct StagedPipeline[num_kv_stages: Int, num_qk_stages: Int = 1]
Unified pipeline that manages the barriers for K, V, and KV tiles.
num_kv_stages is the number of KV tile buffers used for pipelining.
num_qk_stages controls how K loads are staged for the Q@K' MMA:
- K can be loaded in num_qk_stages chunks, so the MMA can start before the full K tile has been loaded.
- V always uses qk_stages=1 (the complete tile is required).
Total stages = num_kv_stages * num_qk_stages.
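The stage bookkeeping described above can be sketched as a small host-side Mojo toy. This is a hypothetical model, not the library's implementation: `ToyPipelineState` and `flat_stage` are invented names, the index/phase behavior (the phase bit flips each time the index wraps) is an assumed CUTLASS-style convention, and the layout that places the num_qk_stages K chunks of one KV buffer in consecutive flat slots is our assumption, so the real barrier layout may differ. What it does reflect from the description is that the state cycles over num_kv_stages buffers while each buffer is split into num_qk_stages sub-stages, for num_kv_stages * num_qk_stages slots in total.

```mojo
# Hypothetical sketch: models only the ring-buffer bookkeeping described
# for StagedPipeline. It does not use the real StagedPipeline API.


struct ToyPipelineState[num_kv_stages: Int]:
    var index: Int  # which KV tile buffer is current (0 ..< num_kv_stages)
    var phase: Int  # 0/1 parity bit, flips whenever index wraps to 0

    fn __init__(out self):
        self.index = 0
        self.phase = 0

    fn step(mut self):
        # Advance to the next KV buffer; flip the phase on wraparound.
        self.index += 1
        if self.index == Self.num_kv_stages:
            self.index = 0
            self.phase = 1 - self.phase


fn flat_stage[num_qk_stages: Int](kv_index: Int, qk_stage: Int) -> Int:
    # Hypothetical layout: the num_qk_stages K chunks of one KV buffer
    # occupy consecutive slots, giving num_kv_stages * num_qk_stages slots.
    return kv_index * num_qk_stages + qk_stage


fn main():
    # Two KV buffers, K split into two chunks: 2 * 2 = 4 flat stages.
    var state = ToyPipelineState[2]()
    for tile in range(5):
        for q in range(2):
            print(
                "tile", tile,
                "kv", state.index,
                "phase", state.phase,
                "qk", q,
                "-> stage", flat_stage[2](state.index, q),
            )
        state.step()
```

Because the phase bit alternates on each pass over the buffers, a waiter can tell a buffer's current fill apart from the fill one full cycle earlier, even though both use the same buffer slot.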
Fields
- mbar (MBarType)
- state (PipelineState[num_kv_stages])
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
num_stages
comptime num_stages = (num_kv_stages * num_qk_stages)
Methods
__init__
__init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
producer_mbar
producer_mbar[qk_stage: Int = 0](self) -> MBarType
Returns:
MBarType
consumer_mbar
consumer_mbar[qk_stage: Int = 0](self, idx: UInt32) -> MBarType
Returns:
MBarType
consumer_mbar[qk_stage: Int = 0](self) -> MBarType
Returns:
MBarType
producer_acquire
producer_acquire[qk_stage: Int = (num_qk_stages - 1)](self)
Waits until the consumer has released the buffer for this stage.
consumer_wait
consumer_wait[qk_stage: Int = (num_qk_stages - 1)](self)
Waits for the producer to complete this stage.
consumer_release
consumer_release[qk_stage: Int = (num_qk_stages - 1)](mut self, e: Int32)
Releases the buffer after this stage has been consumed.
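The producer/consumer protocol formed by producer_acquire, consumer_wait, and consumer_release can be illustrated with a single-threaded toy. It is a hypothetical sketch: `ToyKVRing` and its methods are invented stand-ins, a bit mask replaces the shared-memory mbarriers, and non-blocking checks replace the blocking waits that producer_acquire and consumer_wait perform. It ignores qk sub-stages, which corresponds to num_qk_stages = 1 (the V case).

```mojo
# Hypothetical single-threaded toy: bits in `full` stand in for the
# barrier state of each KV buffer. Real code blocks on mbarriers instead.


struct ToyKVRing[num_kv_stages: Int]:
    var full: Int  # bit i set: buffer i holds a tile not yet released
    var prod_index: Int  # next buffer the producer will fill
    var cons_index: Int  # next buffer the consumer will read

    fn __init__(out self):
        self.full = 0
        self.prod_index = 0
        self.cons_index = 0

    fn producer_can_acquire(self) -> Bool:
        # Mirrors producer_acquire: the buffer must have been released.
        return ((self.full >> self.prod_index) & 1) == 0

    fn produce(mut self):
        # Mirrors the producer signalling that the tile has landed.
        self.full |= 1 << self.prod_index
        self.prod_index = (self.prod_index + 1) % Self.num_kv_stages

    fn consumer_can_wait(self) -> Bool:
        # Mirrors consumer_wait: the producer must have completed the tile.
        return ((self.full >> self.cons_index) & 1) == 1

    fn consume_and_release(mut self):
        # Mirrors consumer_release: hand the buffer back to the producer.
        self.full &= ~(1 << self.cons_index)
        self.cons_index = (self.cons_index + 1) % Self.num_kv_stages


fn main():
    var ring = ToyKVRing[2]()
    var produced = 0
    var consumed = 0
    var total_tiles = 5
    while consumed < total_tiles:
        # The producer runs ahead until every KV buffer is occupied.
        while produced < total_tiles and ring.producer_can_acquire():
            ring.produce()
            produced += 1
        # The consumer drains one tile, freeing its buffer for reuse.
        if ring.consumer_can_wait():
            ring.consume_and_release()
            consumed += 1
    print("produced", produced, "consumed", consumed)
```

In this toy, with two KV buffers the producer can be at most two tiles ahead of the consumer; raising num_kv_stages deepens that lookahead at the cost of more buffer memory.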
num_mbars