Mojo struct

Depth512MBars

struct Depth512MBars[num_kv_stages: Int, split_o: Bool = True]

Manages all mbarrier resources for depth=256/512 pair-CTA attention.

Parameters​

  • ​num_kv_stages (Int): Number of fused KV pipeline buffer slots.
  • ​split_o (Bool): True for depth=512 (O_lo/O_hi split, 10 fixed barriers), False for depth=256 (single O, 8 fixed barriers).

Fields​

  • mbar_base (MBarType): Base pointer to the shared-memory mbarrier array.

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

C_consumer_offset​

comptime C_consumer_offset = (Depth512MBars[num_kv_stages, split_o].C_producer_offset + 1)

C_producer_offset​

comptime C_producer_offset = (Depth512MBars[num_kv_stages, split_o].S_odd_consumer_offset + 1)

KV_barriers​

comptime KV_barriers = (2 * num_kv_stages)

KV_offset​

comptime KV_offset = Depth512MBars[num_kv_stages, split_o].num_fixed

num_fixed​

comptime num_fixed = ((Depth512MBars[num_kv_stages, split_o].O_mma_lo_offset + 1) + (1 if split_o else 0))

num_high_count_barriers​

comptime num_high_count_barriers = (5 + (1 if split_o else 0))

O_mma_hi_offset​

comptime O_mma_hi_offset = (Depth512MBars[num_kv_stages, split_o].O_mma_lo_offset + 1)

O_mma_lo_offset​

comptime O_mma_lo_offset = (Depth512MBars[num_kv_stages, split_o].S_odd_producer_offset + 1)

PO_hi_offset​

comptime PO_hi_offset = (Depth512MBars[num_kv_stages, split_o].PO_lo_offset + 1)

PO_lo_offset​

comptime PO_lo_offset = 0

S_even_consumer_offset​

comptime S_even_consumer_offset = ((Depth512MBars[num_kv_stages, split_o].PO_lo_offset + 1) + (1 if split_o else 0))

S_even_producer_offset​

comptime S_even_producer_offset = (Depth512MBars[num_kv_stages, split_o].C_consumer_offset + 1)

S_odd_consumer_offset​

comptime S_odd_consumer_offset = (Depth512MBars[num_kv_stages, split_o].S_even_consumer_offset + 1)

S_odd_producer_offset​

comptime S_odd_producer_offset = (Depth512MBars[num_kv_stages, split_o].S_even_producer_offset + 1)

size​

comptime size = (Depth512MBars[num_kv_stages, split_o].num_fixed + Depth512MBars[num_kv_stages, split_o].KV_barriers)
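The offset chain above can be checked by hand. The following is a minimal sketch in plain Python (not Mojo) that mirrors the comptime definitions; the function names `fixed_barrier_layout` and `total_size` are hypothetical and not part of the API. It assumes the conditional terms are read as `+ (1 if split_o else 0)`, which is the reading that reproduces the documented counts of 10 and 8 fixed barriers.

```python
def fixed_barrier_layout(split_o: bool) -> dict:
    """Mirror the comptime offset chain for the fixed (non-KV) barriers."""
    extra = 1 if split_o else 0
    layout = {}
    layout["PO_lo"] = 0
    # ASSUMPTION: PO_hi and O_mma_hi are treated as unused when split_o is
    # False. Their offsets are still defined in that case, but they coincide
    # with the next slot, so this sketch omits them.
    if split_o:
        layout["PO_hi"] = layout["PO_lo"] + 1
    layout["S_even_consumer"] = layout["PO_lo"] + 1 + extra
    layout["S_odd_consumer"] = layout["S_even_consumer"] + 1
    layout["C_producer"] = layout["S_odd_consumer"] + 1
    layout["C_consumer"] = layout["C_producer"] + 1
    layout["S_even_producer"] = layout["C_consumer"] + 1
    layout["S_odd_producer"] = layout["S_even_producer"] + 1
    layout["O_mma_lo"] = layout["S_odd_producer"] + 1
    if split_o:
        layout["O_mma_hi"] = layout["O_mma_lo"] + 1
    layout["num_fixed"] = layout["O_mma_lo"] + 1 + extra
    return layout


def total_size(num_kv_stages: int, split_o: bool) -> int:
    """size = num_fixed + KV_barriers, with KV_barriers = 2 * num_kv_stages."""
    return fixed_barrier_layout(split_o)["num_fixed"] + 2 * num_kv_stages


if __name__ == "__main__":
    for split_o in (True, False):
        print(split_o, fixed_barrier_layout(split_o))
```

With split_o = True the fixed barriers occupy indices 0 through 9; with split_o = False they occupy 0 through 7. In both cases KV_offset equals num_fixed, so the KV barriers start immediately after the fixed ones.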

SPipelineConsumer​

comptime SPipelineConsumer = RolePipeline[1, False]

SPipelineProducer​

comptime SPipelineProducer = RolePipeline[1, cta_group=2]

Methods​

__init__​

def __init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

Wrap a base mbarrier pointer from Depth512AttentionSMem.mbar_base().

init​

def init(self, *, lane_idx: Int32)

Initializes all barriers. Call once from each lane of a warp, passing lane_idx = tid % 32.

With num_kv_stages <= 11, the total barrier count is at most 32 (WARP_SIZE), so a single pass over the warp's lanes covers every barrier.
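The bound on num_kv_stages follows from the barrier count. Below is a small sketch in plain Python (not Mojo) of that arithmetic. The helper names are hypothetical, and the mapping "lane i initializes barrier i" is an assumption made for illustration; the source only states that a single wave of 32 lanes covers all barriers.

```python
WARP_SIZE = 32


def barrier_count(num_kv_stages: int, split_o: bool = True) -> int:
    """Total barriers: 10 or 8 fixed ones plus 2 per KV stage."""
    num_fixed = 10 if split_o else 8
    return num_fixed + 2 * num_kv_stages


def lanes_that_initialize(num_kv_stages: int, split_o: bool = True) -> list:
    """Lanes that own a barrier.

    ASSUMPTION: lane i initializes barrier i, and lanes at or beyond the
    barrier count do nothing. The real init() may distribute work differently.
    """
    count = barrier_count(num_kv_stages, split_o)
    if count > WARP_SIZE:
        raise ValueError("more barriers than lanes: one wave is not enough")
    return [lane for lane in range(WARP_SIZE) if lane < count]


if __name__ == "__main__":
    print(barrier_count(11, split_o=True))
    print(len(lanes_that_initialize(11, split_o=False)))
```

For split_o = True and num_kv_stages = 11 the count is 10 + 22 = 32, which exactly fills the warp. With split_o = False the same stage count needs 30 barriers, leaving two lanes idle under the assumed mapping.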

producer_s_even​

def producer_s_even(self) -> Self.SPipelineProducer

MMA warp: acquire waits on S_even consumer (buffer free), commit arrives at S_even producer (S written to TMEM).

Returns:

Self.SPipelineProducer

producer_s_odd​

def producer_s_odd(self) -> Self.SPipelineProducer

MMA warp: acquire waits on S_odd consumer (buffer free), commit arrives at S_odd producer (S written to TMEM).

Returns:

Self.SPipelineProducer

consumer_s_even​

def consumer_s_even(self) -> Self.SPipelineConsumer

Softmax warp: wait on S_even producer (S ready in TMEM), release arrives at S_even consumer (S loaded, buffer free).

Returns:

Self.SPipelineConsumer

consumer_s_odd​

def consumer_s_odd(self) -> Self.SPipelineConsumer

Softmax warp: wait on S_odd producer (S ready in TMEM), release arrives at S_odd consumer (S loaded, buffer free).

Returns:

Self.SPipelineConsumer
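The four S accessors above describe a two-barrier handshake: the producer acquires (waits for the buffer to be free) and commits (signals S is written), while the consumer waits (for S to be ready) and releases (signals the buffer is free). The following toy model in plain Python illustrates that ordering with CPU threads and semaphores. It is a sketch only: `ToyRolePipeline` is a hypothetical stand-in, semaphores replace mbarriers, and the real RolePipeline method signatures are not shown in this page and may differ.

```python
import threading


class ToyRolePipeline:
    """Two-semaphore stand-in for one producer/consumer barrier pair."""

    def __init__(self):
        # Stands in for the consumer barrier; the buffer starts out free.
        self.free = threading.Semaphore(1)
        # Stands in for the producer barrier; nothing is ready at first.
        self.ready = threading.Semaphore(0)

    def acquire(self):   # producer side: wait until the buffer is free
        self.free.acquire()

    def commit(self):    # producer side: signal that data is written
        self.ready.release()

    def wait(self):      # consumer side: wait until data is ready
        self.ready.acquire()

    def release(self):   # consumer side: signal that the buffer is free
        self.free.release()


def run(num_tiles: int) -> list:
    pipe = ToyRolePipeline()
    slot = [None]   # single shared buffer, like one S slot
    seen = []

    def mma_warp():
        for tile in range(num_tiles):
            pipe.acquire()
            slot[0] = f"S[{tile}]"
            pipe.commit()

    def softmax_warp():
        for _ in range(num_tiles):
            pipe.wait()
            seen.append(slot[0])
            pipe.release()

    threads = [threading.Thread(target=mma_warp),
               threading.Thread(target=softmax_warp)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen


if __name__ == "__main__":
    print(run(4))
```

Because the producer cannot overwrite the slot until the consumer releases it, the consumer observes every tile exactly once and in order. The even and odd pipelines would each be one such pair, which lets one S buffer be filled while the other is being read.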

producer_c​

def producer_c(self) -> RolePipeline[1]

Softmax warp: acquire waits on C consumer (buffer free), commit arrives at C producer (correction written to SMEM).

Returns:

RolePipeline[1]

consumer_c​

def consumer_c(self) -> RolePipeline[1, False]

Correction warp: wait on C producer (correction ready), release arrives at C consumer (correction consumed).

Returns:

RolePipeline[1, False]

producer_o_lo​

def producer_o_lo(self) -> RolePipeline[1, cta_group=2]

MMA warp P@V_lo pipeline. Acquire: wait on PO_lo (P ready + previous O_lo rescaled). Commit: mma_arrive at O_mma_lo (V_lo accumulation done).

Returns:

RolePipeline[1, cta_group=2]

producer_o_hi​

def producer_o_hi(self) -> RolePipeline[1, cta_group=2]

MMA warp P@V_hi pipeline. Acquire: wait on PO_hi (previous O_hi rescaled). Commit: mma_arrive at O_mma_hi (V_hi accumulation done).

Returns:

RolePipeline[1, cta_group=2]

consumer_o_lo​

def consumer_o_lo(self) -> RolePipeline[1, False]

Correction warp O_lo pipeline. Wait: wait on O_mma_lo (V_lo accumulation done, O_lo safe to read). Release: arrive at PO_lo (O_lo rescaling done, 128 threads).

Returns:

RolePipeline[1, False]

consumer_o_hi​

def consumer_o_hi(self) -> RolePipeline[1, False]

Correction warp O_hi pipeline. Wait: wait on O_mma_hi (V_hi accumulation done, O_hi safe to read). Release: arrive at PO_hi (O_hi rescaling done, 128 threads).

Returns:

RolePipeline[1, False]

po_lo_mbar​

def po_lo_mbar(self) -> MBarType

Raw pointer to PO_lo barrier.

Softmax arrives here (128 threads) after writing P to SMEM. Correction also arrives (128 threads) via consumer_o_lo().release(). MMA waits for all 256 arrivals before P@V_lo.

Returns:

MBarType

po_hi_mbar​

def po_hi_mbar(self) -> MBarType

Raw pointer to PO_hi barrier.

Correction arrives here (128 threads) via consumer_o_hi().release(). MMA waits for all 128 arrivals before P@V_hi.

Returns:

MBarType
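The arrival counts for PO_lo (256) and PO_hi (128) can be illustrated with a toy arrival counter. This is a plain-Python sketch, not the hardware mbarrier: the `ArrivalCounter` class is hypothetical and only models the idea that a barrier completes a phase once the expected number of arrivals has been recorded.

```python
class ArrivalCounter:
    """Toy barrier: flips its phase bit after `expected` arrivals."""

    def __init__(self, expected: int):
        self.expected = expected
        self.pending = expected
        self.phase = 0

    def arrive(self, count: int = 1) -> None:
        for _ in range(count):
            self.pending -= 1
            if self.pending == 0:
                # Phase complete: flip the bit and re-arm for the next round.
                self.phase ^= 1
                self.pending = self.expected


THREADS_PER_ROLE = 128  # thread count quoted in the text above

# PO_lo expects softmax (128) plus correction (128) arrivals.
po_lo = ArrivalCounter(2 * THREADS_PER_ROLE)
po_lo.arrive(THREADS_PER_ROLE)      # softmax threads: P written
phase_after_softmax = po_lo.phase   # still the old phase, MMA keeps waiting
po_lo.arrive(THREADS_PER_ROLE)      # correction threads: O_lo rescaled
phase_after_both = po_lo.phase      # phase flipped, P@V_lo may start

# PO_hi expects only the correction arrivals.
po_hi = ArrivalCounter(THREADS_PER_ROLE)
po_hi.arrive(THREADS_PER_ROLE)
```

The point of the higher count on PO_lo is that one barrier covers two conditions at once: P is ready and the previous O_lo has been rescaled. PO_hi only needs the second condition, presumably because P is already known to be ready by the time P@V_hi is issued.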

get_kv_mbars​

def get_kv_mbars(self) -> MBarType

Base pointer for KV pipeline barriers.

Layout: 2 × num_kv_stages barriers in {producer, consumer} pairs. Used by load warp (producer) and MMA warp (consumer).

Returns:

MBarType
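As a rough illustration of where the KV barriers sit, the sketch below computes absolute barrier indices for one KV stage in plain Python. The helper `kv_barrier_indices` is hypothetical. That the KV region starts at KV_offset = num_fixed follows from the comptime members; the interleaved ordering (producer for stage i, then consumer for stage i) is an assumption based on the phrase "{producer, consumer} pairs" and is not confirmed by this page.

```python
def kv_barrier_indices(stage: int, num_kv_stages: int, split_o: bool = True):
    """Return (producer_index, consumer_index) for one KV stage.

    ASSUMPTION: pairs are interleaved as [prod0, cons0, prod1, cons1, ...].
    The actual layout could instead group all producers before all consumers.
    """
    if not 0 <= stage < num_kv_stages:
        raise ValueError("stage out of range")
    kv_offset = 10 if split_o else 8   # KV_offset == num_fixed
    producer = kv_offset + 2 * stage
    consumer = producer + 1
    return producer, consumer


if __name__ == "__main__":
    for stage in range(3):
        print(stage, kv_barrier_indices(stage, num_kv_stages=3))
```

Under this assumption, the last consumer index for stage num_kv_stages - 1 is size - 1, so the KV pairs exactly fill the range from KV_offset up to the end of the barrier array.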

num_mbars​

static def num_mbars() -> UInt32

Total number of mbarriers managed by this struct.

Returns:

UInt32