Mojo struct

KVPipeline

@register_passable(trivial) struct KVPipeline[num_kv_stages: Int, num_mma_stages: Int]

KVPipeline has num_kv_stages * num_mma_stages stages in total. num_kv_stages is the number of K and V tiles that are pipelined for the S = Q@K' and O += P@V MMAs. Each of these MMAs is split into num_mma_stages pipelined MMAs. step=False is set for every MMA except the last one, which completes the operation. An alternative implementation would track the KV stages and the MMA stages separately, potentially allowing more overall stages at the cost of slightly more bookkeeping.
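The stage arithmetic described above can be modeled on the host in plain Python. This is only a sketch, not the Mojo implementation: the function name `mma_schedule`, the `num_tiles` argument, and the flat layout `kv_stage * num_mma_stages + mma_stage` are assumptions made for illustration. What it takes from the description is the total stage count and the rule that only the last MMA of each tile steps the pipeline.

```python
def mma_schedule(num_kv_stages: int, num_mma_stages: int, num_tiles: int):
    """Yield (tile, kv_stage, mma_stage, flat_stage, step) for each pipelined MMA.

    Hypothetical model: the real struct tracks this with barriers on the GPU.
    """
    for tile in range(num_tiles):
        # K/V tiles cycle through the num_kv_stages slots.
        kv_stage = tile % num_kv_stages
        for mma_stage in range(num_mma_stages):
            # Assumed flat layout of the num_kv_stages * num_mma_stages stages.
            flat_stage = kv_stage * num_mma_stages + mma_stage
            # Only the MMA that completes the operation advances the pipeline.
            step = mma_stage == num_mma_stages - 1
            yield tile, kv_stage, mma_stage, flat_stage, step


schedule = list(mma_schedule(num_kv_stages=2, num_mma_stages=2, num_tiles=3))
for row in schedule:
    print(row)
```

With two KV stages and two MMA stages, the model visits four distinct flat stages and sets `step` on every second MMA.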

Fields

  • mbar (UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)]):
  • state (PipelineState[num_kv_stages]):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

num_stages

alias num_stages = (num_kv_stages * num_mma_stages)

Methods

__init__

__init__(mbar: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)]) -> Self

init

init(self)

producer_mbar

producer_mbar[mma_stage: Int](self) -> UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)]

Returns:

UnsafePointer

consumer_mbar

consumer_mbar[mma_stage: Int](self, idx: UInt32) -> UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)]

Returns:

UnsafePointer

consumer_mbar[mma_stage: Int](self) -> UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)]

Returns:

UnsafePointer

producer_acquire

producer_acquire[mma_stage: Int = (num_mma_stages - 1)](self)

Returns the dynamic pipe idx.

consumer_wait

consumer_wait[mma_stage: Int = (num_mma_stages - 1)](self)

consumer_release

consumer_release[mma_stage: Int = (num_mma_stages - 1)](mut self)
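The acquire / wait / release names follow the usual producer-consumer pipeline convention, which can be illustrated with a toy Python model. Everything here is an assumption for illustration: `SlotRing` and `producer_commit` are hypothetical, the methods return booleans instead of blocking on a shared-memory barrier as a GPU pipeline would, and the model ignores the per-MMA-stage parameter entirely.

```python
class SlotRing:
    """Toy stand-in for a multi-stage pipeline; not the Mojo implementation."""

    def __init__(self, num_kv_stages: int):
        self.full = [False] * num_kv_stages  # one "full" flag per KV slot
        self.head = 0  # next slot the producer fills
        self.tail = 0  # next slot the consumer drains

    def producer_acquire(self) -> bool:
        """True if the producer's next slot is free to overwrite."""
        return not self.full[self.head]

    def producer_commit(self) -> None:
        """Hypothetical helper: mark the slot filled and advance the producer."""
        self.full[self.head] = True
        self.head = (self.head + 1) % len(self.full)

    def consumer_wait(self) -> bool:
        """True if the consumer's next slot holds data."""
        return self.full[self.tail]

    def consumer_release(self) -> None:
        """Hand the slot back to the producer and advance the consumer."""
        self.full[self.tail] = False
        self.tail = (self.tail + 1) % len(self.full)


ring = SlotRing(num_kv_stages=2)
log = [ring.consumer_wait()]          # nothing produced yet
ring.producer_commit()
ring.producer_commit()
log.append(ring.producer_acquire())   # both slots are full
log.append(ring.consumer_wait())      # data is ready
ring.consumer_release()
log.append(ring.producer_acquire())   # a slot was freed
print(log)
```

In this model the producer stalls once all num_kv_stages slots are full and resumes only after the consumer releases one, which is the back-pressure a multi-stage pipeline is meant to provide.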

num_mbars

static num_mbars() -> UInt32

Returns:

UInt32
