Mojo struct

Depth512MBars

struct Depth512MBars[num_kv_stages: Int]

Manages all mbarrier resources for depth=512 pair-CTA attention.

Parameters

  • num_kv_stages (Int): Number of fused KV pipeline buffer slots.

Fields

  • mbar_base (MBarType): Base pointer to the mbarrier array in shared memory, as passed to __init__.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

C_consumer_offset

comptime C_consumer_offset = 5

C_producer_offset

comptime C_producer_offset = 4

KV_barriers

comptime KV_barriers = (2 * num_kv_stages)

KV_offset

comptime KV_offset = 10

num_high_count_barriers

comptime num_high_count_barriers = 6

O_mma_hi_offset

comptime O_mma_hi_offset = 9

O_mma_lo_offset

comptime O_mma_lo_offset = 8

PO_hi_offset

comptime PO_hi_offset = 1

PO_lo_offset

comptime PO_lo_offset = 0

S_even_consumer_offset

comptime S_even_consumer_offset = 2

S_even_producer_offset

comptime S_even_producer_offset = 6

S_odd_consumer_offset

comptime S_odd_consumer_offset = 3

S_odd_producer_offset

comptime S_odd_producer_offset = 7

size

comptime size = (10 + Depth512MBars[num_kv_stages].KV_barriers)

SPipelineConsumer

comptime SPipelineConsumer = RolePipeline[1, False]

SPipelineProducer

comptime SPipelineProducer = RolePipeline[1, cta_group=2]
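
The comptime values listed above can be cross-checked with a small model. This is a plain-Python sketch, not Mojo and not part of the library; `FIXED_OFFSETS`, `kv_barriers`, and `total_size` are illustrative names that only mirror the constants documented on this page.

```python
# Plain-Python model of the documented comptime layout constants.
# FIXED_OFFSETS, kv_barriers and total_size are illustrative names,
# not library API.
FIXED_OFFSETS = {
    "PO_lo": 0,
    "PO_hi": 1,
    "S_even_consumer": 2,
    "S_odd_consumer": 3,
    "C_producer": 4,
    "C_consumer": 5,
    "S_even_producer": 6,
    "S_odd_producer": 7,
    "O_mma_lo": 8,
    "O_mma_hi": 9,
}
KV_OFFSET = 10


def kv_barriers(num_kv_stages: int) -> int:
    # Mirrors `KV_barriers = 2 * num_kv_stages`.
    return 2 * num_kv_stages


def total_size(num_kv_stages: int) -> int:
    # Mirrors `size = 10 + KV_barriers`.
    return 10 + kv_barriers(num_kv_stages)


# The ten fixed barriers occupy slots 0..9 with no gaps or overlaps,
# which is why the KV barriers can start at KV_offset = 10.
assert sorted(FIXED_OFFSETS.values()) == list(range(KV_OFFSET))

print(total_size(2))  # 14
```

For example, two KV stages give 10 fixed barriers plus 4 KV barriers, 14 in total.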

Methods

__init__

__init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

Wrap a base mbarrier pointer from Depth512AttentionSMem.mbar_base().

init

init(self, *, lane_idx: Int32)

Initialize all barriers. One call per warp lane (lane_idx = tid % 32).

Because size = 10 + 2 × num_kv_stages, any num_kv_stages <= 11 gives size <= 32 = WARP_SIZE, so a single pass over the warp's lanes covers every barrier.
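
The single-pass bound can be checked with a little arithmetic. The following plain-Python sketch is a model, not the Mojo implementation; the helper names are hypothetical, and the lane-to-barrier mapping (lane i initializes barrier i) is an assumption suggested by the `lane_idx` argument rather than something this page states.

```python
WARP_SIZE = 32


def total_size(num_kv_stages: int) -> int:
    # Mirrors the documented `size = 10 + 2 * num_kv_stages`.
    return 10 + 2 * num_kv_stages


def waves_needed(num_kv_stages: int) -> int:
    # Passes a 32-lane warp would need if each lane initializes one
    # barrier per pass (ceiling division).
    return -(-total_size(num_kv_stages) // WARP_SIZE)


def lanes_with_work(num_kv_stages: int) -> list[int]:
    # ASSUMPTION: lane i initializes barrier i when i < size; the
    # remaining lanes do nothing.
    size = total_size(num_kv_stages)
    return [lane for lane in range(WARP_SIZE) if lane < size]


print(total_size(11), waves_needed(11))  # 32 1
print(total_size(12), waves_needed(12))  # 34 2
print(len(lanes_with_work(3)))           # 16
```

With 12 stages the model needs a second pass, which is the case the documented bound of 11 rules out.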

producer_s_even

producer_s_even(self) -> Depth512MBars[num_kv_stages].SPipelineProducer

MMA warp: acquire waits on S_even consumer (buffer free), commit arrives at S_even producer (S written to TMEM).

Returns:

Depth512MBars[num_kv_stages].SPipelineProducer

producer_s_odd

producer_s_odd(self) -> Depth512MBars[num_kv_stages].SPipelineProducer

MMA warp: acquire waits on S_odd consumer (buffer free), commit arrives at S_odd producer (S written to TMEM).

Returns:

Depth512MBars[num_kv_stages].SPipelineProducer

consumer_s_even

consumer_s_even(self) -> Depth512MBars[num_kv_stages].SPipelineConsumer

Softmax warp: wait on S_even producer (S ready in TMEM), release arrives at S_even consumer (S loaded, buffer free).

Returns:

Depth512MBars[num_kv_stages].SPipelineConsumer

consumer_s_odd

consumer_s_odd(self) -> Depth512MBars[num_kv_stages].SPipelineConsumer

Softmax warp: wait on S_odd producer (S ready in TMEM), release arrives at S_odd consumer (S loaded, buffer free).

Returns:

Depth512MBars[num_kv_stages].SPipelineConsumer
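
The acquire/commit and wait/release handshake described for the S pipelines can be illustrated with a toy simulation. This is a plain-Python sketch under simplifying assumptions: `PhaseBarrier` and `simulate` are hypothetical names, each barrier takes a single arrival per phase, and the first acquire is modeled as passing immediately. Real mbarriers track phase parity rather than a running count.

```python
class PhaseBarrier:
    """Toy single-arrival barrier that counts completed phases."""

    def __init__(self) -> None:
        self.completed = 0

    def arrive(self) -> None:
        self.completed += 1


def simulate(num_tiles: int) -> list[str]:
    producer_bar = PhaseBarrier()  # stands in for the S_even producer barrier
    consumer_bar = PhaseBarrier()  # stands in for the S_even consumer barrier
    produced = consumed = 0
    trace = []
    while consumed < num_tiles:
        # Producer acquire: the buffer is free once every earlier tile
        # has been released by the consumer.
        if produced < num_tiles and consumer_bar.completed >= produced:
            producer_bar.arrive()  # commit: S written
            trace.append(f"mma writes S[{produced}]")
            produced += 1
        # Consumer wait: tile `consumed` is ready once it was committed.
        elif producer_bar.completed > consumed:
            consumer_bar.arrive()  # release: S loaded, buffer free
            trace.append(f"softmax reads S[{consumed}]")
            consumed += 1
        else:
            raise RuntimeError("neither role can make progress")
    return trace


for event in simulate(2):
    print(event)
```

With a single buffer slot per pipeline, the model strictly alternates writes and reads, which is the ordering the two barriers are there to enforce.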

producer_c

producer_c(self) -> RolePipeline[1]

Softmax warp: acquire waits on C consumer (buffer free), commit arrives at C producer (correction written to SMEM).

Returns:

RolePipeline

consumer_c

consumer_c(self) -> RolePipeline[1, False]

Correction warp: wait on C producer (correction ready), release arrives at C consumer (correction consumed).

Returns:

RolePipeline

producer_o_lo

producer_o_lo(self) -> RolePipeline[1, cta_group=2]

MMA warp P@V_lo pipeline. Acquire: wait on PO_lo (P ready + previous O_lo rescaled). Commit: mma_arrive at O_mma_lo (V_lo accumulation done).

Returns:

RolePipeline

producer_o_hi

producer_o_hi(self) -> RolePipeline[1, cta_group=2]

MMA warp P@V_hi pipeline. Acquire: wait on PO_hi (previous O_hi rescaled). Commit: mma_arrive at O_mma_hi (V_hi accumulation done).

Returns:

RolePipeline

consumer_o_lo

consumer_o_lo(self) -> RolePipeline[1, False]

Correction warp O_lo pipeline. Wait: wait on O_mma_lo (V_lo accumulation done, O_lo safe to read). Release: arrive at PO_lo (O_lo rescaling done, 128 threads).

Returns:

RolePipeline

consumer_o_hi

consumer_o_hi(self) -> RolePipeline[1, False]

Correction warp O_hi pipeline. Wait: wait on O_mma_hi (V_hi accumulation done, O_hi safe to read). Release: arrive at PO_hi (O_hi rescaling done, 128 threads).

Returns:

RolePipeline
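
The ten pipeline accessors above pair up so that each fixed barrier is waited on by exactly one role and arrived at by exactly one role. The table below is a plain-Python summary of the method descriptions on this page; `OFFSETS` and `ROLE_BARRIERS` are illustrative names, not library API.

```python
# Offsets copied from the comptime members documented on this page.
OFFSETS = {
    "PO_lo": 0, "PO_hi": 1,
    "S_even_consumer": 2, "S_odd_consumer": 3,
    "C_producer": 4, "C_consumer": 5,
    "S_even_producer": 6, "S_odd_producer": 7,
    "O_mma_lo": 8, "O_mma_hi": 9,
}

# method name -> (barrier it waits on, barrier it arrives at)
ROLE_BARRIERS = {
    "producer_s_even": ("S_even_consumer", "S_even_producer"),
    "producer_s_odd": ("S_odd_consumer", "S_odd_producer"),
    "consumer_s_even": ("S_even_producer", "S_even_consumer"),
    "consumer_s_odd": ("S_odd_producer", "S_odd_consumer"),
    "producer_c": ("C_consumer", "C_producer"),
    "consumer_c": ("C_producer", "C_consumer"),
    "producer_o_lo": ("PO_lo", "O_mma_lo"),
    "producer_o_hi": ("PO_hi", "O_mma_hi"),
    "consumer_o_lo": ("O_mma_lo", "PO_lo"),
    "consumer_o_hi": ("O_mma_hi", "PO_hi"),
}

waited = sorted(OFFSETS[w] for w, _ in ROLE_BARRIERS.values())
arrived = sorted(OFFSETS[a] for _, a in ROLE_BARRIERS.values())

# Every fixed slot 0..9 appears once as a wait target and once as an
# arrive target.
assert waited == list(range(10))
assert arrived == list(range(10))

# Each consumer_* mirrors its producer_*: it waits where the producer
# arrives and arrives where the producer waits.
for name, (wait_on, arrive_at) in ROLE_BARRIERS.items():
    if name.startswith("producer_"):
        mirror = "consumer_" + name[len("producer_"):]
        assert ROLE_BARRIERS[mirror] == (arrive_at, wait_on)
```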

po_lo_mbar

po_lo_mbar(self) -> MBarType

Raw pointer to PO_lo barrier.

Softmax arrives here (128 threads) after writing P to SMEM. Correction also arrives (128 threads) via consumer_o_lo().release(). MMA waits for all 256 arrivals before P@V_lo.

Returns:

MBarType

po_hi_mbar

po_hi_mbar(self) -> MBarType

Raw pointer to PO_hi barrier.

Correction arrives here (128 threads) via consumer_o_hi().release(). MMA waits for all 128 arrivals before P@V_hi.

Returns:

MBarType
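
The arrival counts quoted for PO_lo (128 + 128 = 256) and PO_hi (128) can be illustrated with a toy counting barrier. This is a plain-Python sketch of the general mbarrier idea, where a phase completes once the expected number of arrivals is reached; `CountingBarrier` is a hypothetical name, and treating 256 and 128 as the initialized expected counts is an inference from the text above.

```python
class CountingBarrier:
    """Toy model: a phase completes after `expected` arrivals."""

    def __init__(self, expected: int) -> None:
        self.expected = expected
        self.pending = expected
        self.phase = 0

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            # Phase complete: flip the phase bit and re-arm the count.
            self.phase ^= 1
            self.pending = self.expected


THREADS_PER_ROLE = 128  # softmax and correction each arrive with 128 threads

po_lo = CountingBarrier(expected=2 * THREADS_PER_ROLE)  # softmax + correction
po_hi = CountingBarrier(expected=THREADS_PER_ROLE)      # correction only

for _ in range(THREADS_PER_ROLE):  # softmax: P written to SMEM
    po_lo.arrive()
print(po_lo.phase)  # 0: MMA must still wait for the correction arrivals

for _ in range(THREADS_PER_ROLE):  # correction: O_lo and O_hi rescaled
    po_lo.arrive()
    po_hi.arrive()
print(po_lo.phase, po_hi.phase)  # 1 1: both P@V_lo and P@V_hi may proceed
```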

get_kv_mbars

get_kv_mbars(self) -> MBarType

Base pointer for KV pipeline barriers.

Layout: 2 × num_kv_stages barriers, starting at KV_offset, arranged in {producer, consumer} pairs. The load warp acts as producer and the MMA warp as consumer.

Returns:

MBarType
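
One way to picture the KV barrier region is sketched below in plain Python. The page only says the barriers come in {producer, consumer} pairs; the interleaved ordering used here (producer at the even relative index, consumer at the odd one) is an assumption for illustration, and `kv_barrier_slots` is a hypothetical helper, not library API.

```python
KV_OFFSET = 10  # documented comptime KV_offset


def kv_barrier_slots(num_kv_stages: int) -> list[tuple[int, int]]:
    # ASSUMPTION: pairs are interleaved, so stage s uses relative indices
    # (2 * s, 2 * s + 1) from the pointer returned by get_kv_mbars().
    # The returned values are absolute slots counted from mbar_base.
    return [
        (KV_OFFSET + 2 * stage, KV_OFFSET + 2 * stage + 1)
        for stage in range(num_kv_stages)
    ]


slots = kv_barrier_slots(2)
print(slots)  # [(10, 11), (12, 13)]

# Independent of the ordering assumption, the region ends exactly at
# the documented total size, 10 + 2 * num_kv_stages.
assert slots[-1][1] + 1 == 10 + 2 * 2
```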

num_mbars

static num_mbars() -> UInt32

Total number of mbarriers managed by this struct.

Returns:

UInt32
