Mojo struct

Depth512MBars

struct Depth512MBars[num_kv_stages: Int, split_o: Bool = True]

Manages all mbarrier resources for depth=256/512 pair-CTA attention.

Parameters

  • num_kv_stages (Int): Number of fused KV pipeline buffer slots.
  • split_o (Bool): True for depth=512 (O_lo/O_hi split, 10 fixed barriers), False for depth=256 (single O, 8 fixed barriers).

Fields

  • mbar_base (MBarType): Base pointer to the mbarrier array in shared memory.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

C_consumer_offset

comptime C_consumer_offset = (Depth512MBars[num_kv_stages, split_o].C_producer_offset + 1)

C_producer_offset

comptime C_producer_offset = (Depth512MBars[num_kv_stages, split_o].S_odd_consumer_offset + 1)

KV_barriers

comptime KV_barriers = (2 * num_kv_stages)

KV_offset

comptime KV_offset = Depth512MBars[num_kv_stages, split_o].num_fixed

num_fixed

comptime num_fixed = ((Depth512MBars[num_kv_stages, split_o].O_mma_lo_offset + 1) + (1 if split_o else 0))

num_high_count_barriers

comptime num_high_count_barriers = (5 + (1 if split_o else 0))

O_mma_hi_offset

comptime O_mma_hi_offset = (Depth512MBars[num_kv_stages, split_o].O_mma_lo_offset + 1)

O_mma_lo_offset

comptime O_mma_lo_offset = (Depth512MBars[num_kv_stages, split_o].S_odd_producer_offset + 1)

PO_hi_offset

comptime PO_hi_offset = (Depth512MBars[num_kv_stages, split_o].PO_lo_offset + 1)

PO_lo_offset

comptime PO_lo_offset = 0

S_even_consumer_offset

comptime S_even_consumer_offset = ((Depth512MBars[num_kv_stages, split_o].PO_lo_offset + 1) + (1 if split_o else 0))

S_even_producer_offset

comptime S_even_producer_offset = (Depth512MBars[num_kv_stages, split_o].C_consumer_offset + 1)

S_odd_consumer_offset

comptime S_odd_consumer_offset = (Depth512MBars[num_kv_stages, split_o].S_even_consumer_offset + 1)

S_odd_producer_offset

comptime S_odd_producer_offset = (Depth512MBars[num_kv_stages, split_o].S_even_producer_offset + 1)

size

comptime size = (Depth512MBars[num_kv_stages, split_o].num_fixed + Depth512MBars[num_kv_stages, split_o].KV_barriers)
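The offset members chain off one another, so the resulting layout is easier to see when evaluated. The following is a minimal Python model of that arithmetic, not the Mojo API: the function names `fixed_offsets` and `kv_barrier_range` are hypothetical, and the conditional term is read as `+ (1 if split_o else 0)`, which is the reading that reproduces the 10 and 8 fixed barriers stated under Parameters.

```python
def fixed_offsets(split_o: bool) -> dict:
    """Python model of the comptime offset chain; names are illustrative only."""
    extra = 1 if split_o else 0  # one additional slot where the O split adds a barrier
    off = {"PO_lo": 0}
    off["PO_hi"] = off["PO_lo"] + 1
    off["S_even_consumer"] = off["PO_lo"] + 1 + extra
    off["S_odd_consumer"] = off["S_even_consumer"] + 1
    off["C_producer"] = off["S_odd_consumer"] + 1
    off["C_consumer"] = off["C_producer"] + 1
    off["S_even_producer"] = off["C_consumer"] + 1
    off["S_odd_producer"] = off["S_even_producer"] + 1
    off["O_mma_lo"] = off["S_odd_producer"] + 1
    off["O_mma_hi"] = off["O_mma_lo"] + 1
    off["num_fixed"] = off["O_mma_lo"] + 1 + extra
    return off


def kv_barrier_range(num_kv_stages: int, split_o: bool) -> range:
    """Indices occupied by the KV barriers: [KV_offset, size)."""
    kv_offset = fixed_offsets(split_o)["num_fixed"]  # KV_offset = num_fixed
    size = kv_offset + 2 * num_kv_stages             # size = num_fixed + KV_barriers
    return range(kv_offset, size)


if __name__ == "__main__":
    for split_o in (True, False):
        off = fixed_offsets(split_o)
        print(split_o, off["num_fixed"], list(kv_barrier_range(2, split_o)))
```

In this model, with `split_o = False` the `PO_hi` offset coincides with `S_even_consumer` and `O_mma_hi` coincides with `KV_offset`. That suggests the two `_hi` offsets are only used when `split_o` is True, though this is an inference from the arithmetic rather than something the page states.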

SPipelineConsumer

comptime SPipelineConsumer = RolePipeline[1, False]

SPipelineProducer

comptime SPipelineProducer = RolePipeline[1, cta_group=2]

Methods

__init__

__init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

Wrap a base mbarrier pointer from Depth512AttentionSMem.mbar_base().

init

init(self, *, lane_idx: Int32)

Initializes all barriers. Call once from each lane of a warp, passing lane_idx = tid % 32.

With split_o = True (10 fixed barriers) and num_kv_stages <= 11, the total size is at most 10 + 2 × 11 = 32 = WARP_SIZE, so a single wave covers all barriers.
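The single-wave bound can be checked numerically. This is a small Python sketch, not the Mojo implementation: the helper names are hypothetical, and the lane-to-barrier mapping (lane i initializes barrier i, higher lanes idle) is an assumption about how `init` distributes the work.

```python
WARP_SIZE = 32


def total_barriers(num_kv_stages: int, split_o: bool = True) -> int:
    """size = num_fixed + 2 * num_kv_stages, with num_fixed = 10 or 8."""
    num_fixed = 10 if split_o else 8
    return num_fixed + 2 * num_kv_stages


def fits_single_wave(num_kv_stages: int, split_o: bool = True) -> bool:
    """True when one barrier per lane is enough to cover every barrier."""
    return total_barriers(num_kv_stages, split_o) <= WARP_SIZE


def active_lanes(num_kv_stages: int, split_o: bool = True) -> list:
    """Assumed mapping: lane i initializes barrier i; lanes >= size do nothing."""
    size = total_barriers(num_kv_stages, split_o)
    return [lane for lane in range(WARP_SIZE) if lane < size]


if __name__ == "__main__":
    for stages in (3, 11, 12):
        print(stages, total_barriers(stages), fits_single_wave(stages))
```

Under this model, 11 stages fill the warp exactly when `split_o` is True, and a 12th stage would need a second pass.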

producer_s_even

producer_s_even(self) -> Depth512MBars[num_kv_stages, split_o].SPipelineProducer

MMA warp: acquire waits on S_even consumer (buffer free), commit arrives at S_even producer (S written to TMEM).

Returns:

SPipelineProducer (an alias of RolePipeline[1, cta_group=2])

producer_s_odd

producer_s_odd(self) -> Depth512MBars[num_kv_stages, split_o].SPipelineProducer

MMA warp: acquire waits on S_odd consumer (buffer free), commit arrives at S_odd producer (S written to TMEM).

Returns:

SPipelineProducer (an alias of RolePipeline[1, cta_group=2])

consumer_s_even

consumer_s_even(self) -> Depth512MBars[num_kv_stages, split_o].SPipelineConsumer

Softmax warp: wait on S_even producer (S ready in TMEM), release arrives at S_even consumer (S loaded, buffer free).

Returns:

SPipelineConsumer (an alias of RolePipeline[1, False])

consumer_s_odd

consumer_s_odd(self) -> Depth512MBars[num_kv_stages, split_o].SPipelineConsumer

Softmax warp: wait on S_odd producer (S ready in TMEM), release arrives at S_odd consumer (S loaded, buffer free).

Returns:

SPipelineConsumer (an alias of RolePipeline[1, False])
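The four S accessors describe a handshake: the producer acquires a free buffer and commits once S is written, and the consumer waits for S and releases the buffer once it has been loaded. The sketch below is a single-threaded Python toy of that protocol, not the `RolePipeline` API. The class `OneStagePipeline` and the function `run_tiles` are hypothetical, and the strict even/odd alternation of tiles is an assumption drawn from the naming.

```python
class OneStagePipeline:
    """Toy single-slot handshake between one producer and one consumer."""

    def __init__(self):
        self.full = False  # True between producer commit and consumer release
        self.slot = None

    def acquire(self) -> bool:
        """Producer side: non-blocking check that the buffer is free."""
        return not self.full

    def commit(self, value) -> None:
        """Producer side: publish a value (models 'S written to TMEM')."""
        assert not self.full, "commit on a buffer that was never released"
        self.slot, self.full = value, True

    def wait(self) -> bool:
        """Consumer side: non-blocking check that a value is ready."""
        return self.full

    def release(self):
        """Consumer side: take the value and free the buffer."""
        assert self.full, "release on an empty buffer"
        value, self.slot, self.full = self.slot, None, False
        return value


def run_tiles(num_tiles: int) -> list:
    """Alternate tiles between an even and an odd buffer (assumed schedule)."""
    pipes = [OneStagePipeline(), OneStagePipeline()]  # index 0 = even, 1 = odd
    produced, consumed = 0, []
    while len(consumed) < num_tiles:
        # The producer runs ahead until the buffer it needs is still occupied.
        while produced < num_tiles and pipes[produced % 2].acquire():
            pipes[produced % 2].commit(produced)
            produced += 1
        # The consumer drains the oldest outstanding tile.
        pipe = pipes[len(consumed) % 2]
        assert pipe.wait()
        consumed.append(pipe.release())
    return consumed


if __name__ == "__main__":
    print(run_tiles(5))
```

With two buffers the producer can fill one while the consumer drains the other, which is the usual motivation for an even/odd pair.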

producer_c

producer_c(self) -> RolePipeline[1]

Softmax warp: acquire waits on C consumer (buffer free), commit arrives at C producer (correction written to SMEM).

Returns:

RolePipeline

consumer_c

consumer_c(self) -> RolePipeline[1, False]

Correction warp: wait on C producer (correction ready), release arrives at C consumer (correction consumed).

Returns:

RolePipeline

producer_o_lo

producer_o_lo(self) -> RolePipeline[1, cta_group=2]

MMA warp P@V_lo pipeline. Acquire: wait on PO_lo (P ready + previous O_lo rescaled). Commit: mma_arrive at O_mma_lo (V_lo accumulation done).

Returns:

RolePipeline

producer_o_hi

producer_o_hi(self) -> RolePipeline[1, cta_group=2]

MMA warp P@V_hi pipeline. Acquire: wait on PO_hi (previous O_hi rescaled). Commit: mma_arrive at O_mma_hi (V_hi accumulation done).

Returns:

RolePipeline

consumer_o_lo

consumer_o_lo(self) -> RolePipeline[1, False]

Correction warp O_lo pipeline. Wait: on O_mma_lo (V_lo accumulation done, O_lo safe to read). Release: arrive at PO_lo (O_lo rescaling done, 128 threads).

Returns:

RolePipeline

consumer_o_hi

consumer_o_hi(self) -> RolePipeline[1, False]

Correction warp O_hi pipeline. Wait: on O_mma_hi (V_hi accumulation done, O_hi safe to read). Release: arrive at PO_hi (O_hi rescaling done, 128 threads).

Returns:

RolePipeline

po_lo_mbar

po_lo_mbar(self) -> MBarType

Raw pointer to PO_lo barrier.

Softmax arrives here (128 threads) after writing P to SMEM. Correction also arrives (128 threads) via consumer_o_lo().release(). MMA waits for all 256 arrivals before P@V_lo.

Returns:

MBarType

po_hi_mbar

po_hi_mbar(self) -> MBarType

Raw pointer to PO_hi barrier.

Correction arrives here (128 threads) via consumer_o_hi().release(). MMA waits for all 128 arrivals before P@V_hi.

Returns:

MBarType
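The two PO barriers differ only in how many arrivals complete a phase: 256 for PO_lo (softmax plus correction) and 128 for PO_hi (correction only). The following Python toy models an arrival-counting barrier to make that difference concrete. It is a simplified sketch, not hardware mbarrier semantics, and the class `PhaseBarrier` is hypothetical.

```python
class PhaseBarrier:
    """Toy arrival-counting barrier: the phase flips after `expected` arrivals."""

    def __init__(self, expected: int):
        self.expected = expected
        self.pending = expected
        self.phase = 0

    def arrive(self, count: int = 1) -> None:
        """Record `count` arrivals, one at a time."""
        for _ in range(count):
            self.pending -= 1
            if self.pending == 0:
                self.phase ^= 1               # phase complete, flip parity
                self.pending = self.expected  # re-arm for the next phase

    def phase_done(self, phase: int) -> bool:
        """True once the barrier has moved past the given phase."""
        return self.phase != phase


if __name__ == "__main__":
    po_lo = PhaseBarrier(expected=256)  # softmax (128) + correction (128)
    po_hi = PhaseBarrier(expected=128)  # correction (128) only

    po_lo.arrive(128)                   # softmax threads: P written
    print("PO_lo after softmax:", po_lo.phase_done(0))
    po_lo.arrive(128)                   # correction threads: O_lo rescaled
    print("PO_lo after correction:", po_lo.phase_done(0))

    po_hi.arrive(128)                   # correction threads: O_hi rescaled
    print("PO_hi after correction:", po_hi.phase_done(0))
```

In this model the MMA side would proceed with P@V_lo only after both groups have arrived at PO_lo, while P@V_hi depends on the correction group alone.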

get_kv_mbars

get_kv_mbars(self) -> MBarType

Base pointer for KV pipeline barriers.

Layout: 2 × num_kv_stages barriers in {producer, consumer} pairs. Used by load warp (producer) and MMA warp (consumer).

Returns:

MBarType

num_mbars

static num_mbars() -> UInt32

Total number of mbarriers managed by this struct.

Returns:

UInt32
