
Mojo struct

StagedPipeline

@register_passable(trivial) struct StagedPipeline[num_kv_stages: Int, num_qk_stages: Int = 1]

Unified pipeline for K, V, and KV tile barrier management.

num_kv_stages is the number of KV tile buffers available for pipelining. num_qk_stages controls how K loading is staged for the Q@K' MMA:

  • K can be loaded in num_qk_stages chunks, allowing the MMA to start earlier
  • V always uses qk_stages=1 (the complete tile is required)

Total stages = num_kv_stages * num_qk_stages.
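
The stage count above is plain arithmetic, and a small host-side model can make it concrete. The sketch below is ordinary Python, not Mojo and not part of this API; `total_stages` is a hypothetical helper that only mirrors the documented relation `num_stages = num_kv_stages * num_qk_stages`, with `num_qk_stages` defaulting to 1 as in the struct signature.

```python
def total_stages(num_kv_stages: int, num_qk_stages: int = 1) -> int:
    """Hypothetical helper mirroring the documented relation
    num_stages = num_kv_stages * num_qk_stages."""
    if num_kv_stages < 1 or num_qk_stages < 1:
        raise ValueError("stage counts must be positive")
    return num_kv_stages * num_qk_stages


if __name__ == "__main__":
    # A few illustrative configurations (values chosen for the example only).
    for kv, qk in [(2, 1), (2, 2), (3, 4)]:
        print(f"num_kv_stages={kv}, num_qk_stages={qk} -> num_stages={total_stages(kv, qk)}")
```

With the default num_qk_stages of 1, the total equals num_kv_stages; splitting K loads into more chunks multiplies the stage count accordingly.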

Fields

  • mbar (MBarType)
  • state (PipelineState[num_kv_stages])

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

num_stages

comptime num_stages = (num_kv_stages * num_qk_stages)

Methods

__init__

__init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

producer_mbar

producer_mbar[qk_stage: Int = 0](self) -> MBarType

Returns:

MBarType

consumer_mbar

consumer_mbar[qk_stage: Int = 0](self, idx: UInt32) -> MBarType

Returns:

MBarType

consumer_mbar[qk_stage: Int = 0](self) -> MBarType

Returns:

MBarType

producer_acquire

producer_acquire[qk_stage: Int = (num_qk_stages - 1)](self)

Wait until the consumer has released the buffer for this stage.

consumer_wait

consumer_wait[qk_stage: Int = (num_qk_stages - 1)](self)

Wait for the producer to complete this stage.

consumer_release

consumer_release[qk_stage: Int = (num_qk_stages - 1)](mut self, e: Int32)

Release the buffer after consuming this stage.
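
The acquire / wait / release handshake described by the three methods above can be illustrated with a host-side toy model. This is a sketch in plain Python under stated assumptions, not the Mojo implementation: it uses `threading.Semaphore` in place of shared-memory barriers, passes the stage index explicitly instead of tracking it in a `PipelineState`, ignores `qk_stage` and the `e` argument, and introduces a hypothetical `producer_commit` step that is not part of the documented API.

```python
import threading


class RingPipelineModel:
    """Toy host-side model of a staged producer/consumer ring (not the Mojo API)."""

    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.buffers = [None] * num_stages
        # Assumption: one "buffer is free" and one "buffer is filled" signal per stage.
        self.empty = [threading.Semaphore(1) for _ in range(num_stages)]
        self.full = [threading.Semaphore(0) for _ in range(num_stages)]

    def producer_acquire(self, stage: int) -> None:
        # Block until the consumer has released this stage's buffer.
        self.empty[stage].acquire()

    def producer_commit(self, stage: int, tile) -> None:
        # Hypothetical step: store the tile and signal that the stage is filled.
        self.buffers[stage] = tile
        self.full[stage].release()

    def consumer_wait(self, stage: int):
        # Block until the producer has filled this stage, then return its tile.
        self.full[stage].acquire()
        return self.buffers[stage]

    def consumer_release(self, stage: int) -> None:
        # Hand the buffer back to the producer.
        self.empty[stage].release()


def run(num_stages: int, num_tiles: int) -> list:
    pipe = RingPipelineModel(num_stages)
    consumed = []

    def producer():
        for i in range(num_tiles):
            stage = i % num_stages  # stages are reused round-robin
            pipe.producer_acquire(stage)
            pipe.producer_commit(stage, f"tile{i}")

    def consumer():
        for i in range(num_tiles):
            stage = i % num_stages
            consumed.append(pipe.consumer_wait(stage))
            pipe.consumer_release(stage)

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return consumed


if __name__ == "__main__":
    print(run(num_stages=2, num_tiles=5))
```

In this model the producer can run at most `num_stages` tiles ahead of the consumer, which is the property that multiple buffers are meant to provide; the real struct's barrier layout and phase tracking may differ.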

num_mbars

static num_mbars() -> UInt32

Returns:

UInt32
