Mojo struct
Depth512MBars
struct Depth512MBars[num_kv_stages: Int, split_o: Bool = True]
Manages all mbarrier resources for depth=256/512 pair-CTA attention.
Parameters
- num_kv_stages (Int): Number of fused KV pipeline buffer slots.
- split_o (Bool): True for depth=512 (O_lo/O_hi split, 10 fixed barriers), False for depth=256 (single O, 8 fixed barriers).
Fields
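- mbar_base (MBarType): Base pointer to the mbarrier array in shared memory, as returned by Depth512AttentionSMem.mbar_base().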
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
C_consumer_offset
comptime C_consumer_offset = (Depth512MBars[num_kv_stages, split_o].C_producer_offset + 1)
C_producer_offset
comptime C_producer_offset = (Depth512MBars[num_kv_stages, split_o].S_odd_consumer_offset + 1)
KV_barriers
comptime KV_barriers = (2 * num_kv_stages)
KV_offset
comptime KV_offset = Depth512MBars[num_kv_stages, split_o].num_fixed
num_fixed
comptime num_fixed = (Depth512MBars[num_kv_stages, split_o].O_mma_lo_offset + 1 + (1 if split_o else 0))
num_high_count_barriers
comptime num_high_count_barriers = (5 + (1 if split_o else 0))
O_mma_hi_offset
comptime O_mma_hi_offset = (Depth512MBars[num_kv_stages, split_o].O_mma_lo_offset + 1)
O_mma_lo_offset
comptime O_mma_lo_offset = (Depth512MBars[num_kv_stages, split_o].S_odd_producer_offset + 1)
PO_hi_offset
comptime PO_hi_offset = (Depth512MBars[num_kv_stages, split_o].PO_lo_offset + 1)
PO_lo_offset
comptime PO_lo_offset = 0
S_even_consumer_offset
comptime S_even_consumer_offset = (Depth512MBars[num_kv_stages, split_o].PO_lo_offset + 1 + (1 if split_o else 0))
S_even_producer_offset
comptime S_even_producer_offset = (Depth512MBars[num_kv_stages, split_o].C_consumer_offset + 1)
S_odd_consumer_offset
comptime S_odd_consumer_offset = (Depth512MBars[num_kv_stages, split_o].S_even_consumer_offset + 1)
S_odd_producer_offset
comptime S_odd_producer_offset = (Depth512MBars[num_kv_stages, split_o].S_even_producer_offset + 1)
size
comptime size = (Depth512MBars[num_kv_stages, split_o].num_fixed + Depth512MBars[num_kv_stages, split_o].KV_barriers)
SPipelineConsumer
comptime SPipelineConsumer = RolePipeline[1, False]
SPipelineProducer
comptime SPipelineProducer = RolePipeline[1, cta_group=2]
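The offsets above are defined in terms of one another, so the concrete layout is easiest to see by evaluating the chain. The following is a minimal host-side Python model, not the Mojo implementation; `depth512_mbar_layout` and `used_fixed_names` are hypothetical helpers. It reads the split_o-dependent term as `+ (1 if split_o else 0)`, which is the reading consistent with the 10 and 8 fixed barriers stated under Parameters.

```python
# Host-side model of the Depth512MBars offset chain (hypothetical helpers,
# not part of the Mojo API). The split_o-dependent term is read as
# "+ (1 if split_o else 0)", which yields 10 / 8 fixed barriers.

FIXED_NAMES = [
    "PO_lo", "PO_hi", "S_even_consumer", "S_odd_consumer", "C_producer",
    "C_consumer", "S_even_producer", "S_odd_producer", "O_mma_lo", "O_mma_hi",
]


def depth512_mbar_layout(num_kv_stages: int, split_o: bool = True) -> dict:
    extra = 1 if split_o else 0
    off = {"PO_lo": 0}
    off["PO_hi"] = off["PO_lo"] + 1
    off["S_even_consumer"] = off["PO_lo"] + 1 + extra
    off["S_odd_consumer"] = off["S_even_consumer"] + 1
    off["C_producer"] = off["S_odd_consumer"] + 1
    off["C_consumer"] = off["C_producer"] + 1
    off["S_even_producer"] = off["C_consumer"] + 1
    off["S_odd_producer"] = off["S_even_producer"] + 1
    off["O_mma_lo"] = off["S_odd_producer"] + 1
    off["O_mma_hi"] = off["O_mma_lo"] + 1
    off["num_fixed"] = off["O_mma_lo"] + 1 + extra
    off["num_high_count_barriers"] = 5 + extra
    off["KV"] = off["num_fixed"]  # KV_offset: KV barriers follow the fixed ones
    off["KV_barriers"] = 2 * num_kv_stages
    off["size"] = off["num_fixed"] + off["KV_barriers"]
    return off


def used_fixed_names(split_o: bool) -> list:
    # With a single O (split_o=False) the *_hi barriers are not counted.
    if split_o:
        return FIXED_NAMES
    return [n for n in FIXED_NAMES if not n.endswith("_hi")]


for split_o in (True, False):
    lay = depth512_mbar_layout(num_kv_stages=3, split_o=split_o)
    print(split_o, {n: lay[n] for n in used_fixed_names(split_o)}, "size =", lay["size"])
```

For split_o=True the ten fixed offsets are exactly 0 through 9. For split_o=False the eight barriers in use occupy 0 through 7, while the PO_hi_offset and O_mma_hi_offset formulas still evaluate to 1 and 8, the same slots as S_even_consumer_offset and KV_offset. By this arithmetic, the *_hi accessors only refer to dedicated barriers when split_o=True.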
Methods
__init__
__init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
Wrap a base mbarrier pointer from Depth512AttentionSMem.mbar_base().
init
init(self, *, lane_idx: Int32)
Initialize all barriers. Call once from each lane of the initializing warp, passing lane_idx = tid % 32.
With num_kv_stages <= 11, size = num_fixed + 2 × num_kv_stages <= 10 + 22 = 32 = WARP_SIZE, so a single wave of 32 lanes covers all barriers.
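The single-wave bound can be checked with a small Python model. The scheme in which lane i initializes barrier i and lanes at or beyond size do nothing is an assumption suggested by the phrase "single wave", not the documented implementation; `init_single_wave` and `max_kv_stages` are hypothetical names.

```python
WARP_SIZE = 32


def num_fixed(split_o: bool) -> int:
    # 10 fixed barriers for depth=512 (split O), 8 for depth=256.
    return 10 if split_o else 8


def init_single_wave(num_kv_stages: int, split_o: bool = True) -> set:
    """Simulate one warp initializing barriers, lane i taking barrier i (assumed)."""
    size = num_fixed(split_o) + 2 * num_kv_stages
    initialized = set()
    for lane_idx in range(WARP_SIZE):   # every lane makes the same init() call
        if lane_idx < size:             # lanes past `size` have no barrier to set up
            initialized.add(lane_idx)
    if initialized != set(range(size)):
        raise ValueError(f"{size} barriers need more than one {WARP_SIZE}-lane wave")
    return initialized


def max_kv_stages(split_o: bool = True) -> int:
    # Largest stage count whose barriers still fit in one wave.
    return (WARP_SIZE - num_fixed(split_o)) // 2


print(max_kv_stages(True), max_kv_stages(False))  # 11 12
```

By this arithmetic alone, the documented bound of 11 is the split_o=True limit; with split_o=False the eight fixed barriers would leave room for a twelfth stage.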
producer_s_even
producer_s_even(self) -> Depth512MBars[num_kv_stages, split_o].SPipelineProducer
MMA warp: acquire waits on S_even consumer (buffer free), commit arrives at S_even producer (S written to TMEM).
Returns:
Depth512MBars[num_kv_stages, split_o].SPipelineProducer
producer_s_odd
producer_s_odd(self) -> Depth512MBars[num_kv_stages, split_o].SPipelineProducer
MMA warp: acquire waits on S_odd consumer (buffer free), commit arrives at S_odd producer (S written to TMEM).
Returns:
Depth512MBars[num_kv_stages, split_o].SPipelineProducer
consumer_s_even
consumer_s_even(self) -> Depth512MBars[num_kv_stages, split_o].SPipelineConsumer
Softmax warp: wait on S_even producer (S ready in TMEM), release arrives at S_even consumer (S loaded, buffer free).
Returns:
Depth512MBars[num_kv_stages, split_o].SPipelineConsumer
consumer_s_odd
consumer_s_odd(self) -> Depth512MBars[num_kv_stages, split_o].SPipelineConsumer
Softmax warp: wait on S_odd producer (S ready in TMEM), release arrives at S_odd consumer (S loaded, buffer free).
Returns:
Depth512MBars[num_kv_stages, split_o].SPipelineConsumer
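Each S buffer uses a two-barrier handshake: the producer waits on the consumer barrier before writing and arrives at the producer barrier afterward, and the consumer does the reverse. The sketch below models that handshake in Python with a simplified parity-based mbarrier. The classes, the single-arrival counts, the convention that the producer starts on parity 1 (so its first acquire passes on a fresh barrier), and the assumption that consecutive KV tiles alternate between S_even and S_odd are all illustrative, not the RolePipeline implementation.

```python
class MBarrier:
    """Toy mbarrier: a phase completes after `count` arrivals (simplified model)."""

    def __init__(self, count: int = 1):
        self.count = count
        self.pending = count
        self.phase = 0  # parity of the phase currently in progress

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:  # phase complete: flip parity and re-arm the count
            self.phase ^= 1
            self.pending = self.count

    def test_wait(self, parity: int) -> bool:
        # The phase with this parity has completed once the current parity differs.
        return self.phase != parity


class OneStagePipeline:
    """Hypothetical model of a 1-stage producer/consumer barrier pair."""

    def __init__(self):
        self.producer_bar = MBarrier()  # e.g. S_even producer: "S written to TMEM"
        self.consumer_bar = MBarrier()  # e.g. S_even consumer: "buffer free"
        self.producer_parity = 1        # assumption: the buffer starts free
        self.consumer_parity = 0

    def can_acquire(self) -> bool:
        return self.consumer_bar.test_wait(self.producer_parity)

    def acquire(self) -> None:          # producer: wait for "buffer free"
        if not self.can_acquire():
            raise RuntimeError("producer would block")
        self.producer_parity ^= 1

    def commit(self) -> None:           # producer: signal "S written"
        self.producer_bar.arrive()

    def can_wait(self) -> bool:
        return self.producer_bar.test_wait(self.consumer_parity)

    def wait(self) -> None:             # consumer: wait for "S ready"
        if not self.can_wait():
            raise RuntimeError("consumer would block")
        self.consumer_parity ^= 1

    def release(self) -> None:          # consumer: signal "S loaded, buffer free"
        self.consumer_bar.arrive()


def demo_s_ping_pong():
    s_even, s_odd = OneStagePipeline(), OneStagePipeline()
    for s in (s_even, s_odd):           # MMA warp writes S for two tiles (even, odd)
        s.acquire()
        s.commit()
    blocked = not s_even.can_acquire()  # the next even tile must wait for S_even
    s_even.wait()                       # softmax warp loads S_even ...
    s_even.release()                    # ... and frees the buffer
    return blocked, s_even.can_acquire()


print(demo_s_ping_pong())  # (True, True)
```

With two single-stage buffers, the MMA warp can compute S for the next tile while softmax is still processing the previous one, but it cannot run more than two tiles ahead.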
producer_c
producer_c(self) -> RolePipeline[1]
Softmax warp: acquire waits on C consumer (buffer free), commit arrives at C producer (correction written to SMEM).
Returns:
RolePipeline
consumer_c
consumer_c(self) -> RolePipeline[1, False]
Correction warp: wait on C producer (correction ready), release arrives at C consumer (correction consumed).
Returns:
RolePipeline
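Because the C pipeline has a single stage, the softmax warp can be at most one correction ahead of the correction warp. The host-side analogue below uses two counting semaphores (a classic one-slot bounded buffer, swapped in here for mbarriers) to show the resulting ordering; `run_c_pipeline` and the integer payload are hypothetical.

```python
import threading


def run_c_pipeline(num_iters: int) -> list:
    free = threading.Semaphore(1)   # stands in for the C consumer barrier ("buffer free")
    ready = threading.Semaphore(0)  # stands in for the C producer barrier ("correction written")
    slot = [None]                   # the single C buffer
    consumed = []

    def softmax_warp():
        for i in range(num_iters):
            free.acquire()          # producer_c().acquire: wait until C is free
            slot[0] = i             # write the correction for iteration i
            ready.release()         # commit: correction is available

    def correction_warp():
        for _ in range(num_iters):
            ready.acquire()         # consumer_c() wait: correction ready
            consumed.append(slot[0])
            free.release()          # release: C may be overwritten

    threads = [threading.Thread(target=softmax_warp),
               threading.Thread(target=correction_warp)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return consumed


print(run_c_pipeline(5))  # [0, 1, 2, 3, 4]
```

No correction is overwritten before it is read, so the correction warp observes every iteration exactly once and in order, regardless of thread scheduling.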
producer_o_lo
producer_o_lo(self) -> RolePipeline[1, cta_group=2]
MMA warp P@V_lo pipeline. Acquire: wait on PO_lo (P ready + previous O_lo rescaled). Commit: mma_arrive at O_mma_lo (V_lo accumulation done).
Returns:
RolePipeline
producer_o_hi
producer_o_hi(self) -> RolePipeline[1, cta_group=2]
MMA warp P@V_hi pipeline. Acquire: wait on PO_hi (previous O_hi rescaled). Commit: mma_arrive at O_mma_hi (V_hi accumulation done).
Returns:
RolePipeline
consumer_o_lo
consumer_o_lo(self) -> RolePipeline[1, False]
Correction warp O_lo pipeline. Wait: on O_mma_lo (V_lo accumulation done, O_lo safe to read). Release: arrive at PO_lo (O_lo rescaling done, 128 threads).
Returns:
RolePipeline
consumer_o_hi
consumer_o_hi(self) -> RolePipeline[1, False]
Correction warp O_hi pipeline. Wait: on O_mma_hi (V_hi accumulation done, O_hi safe to read). Release: arrive at PO_hi (O_hi rescaling done, 128 threads).
Returns:
RolePipeline
po_lo_mbar
po_lo_mbar(self) -> MBarType
Raw pointer to PO_lo barrier.
Softmax arrives here (128 threads) after writing P to SMEM, and correction arrives here (128 threads) via consumer_o_lo().release(). MMA waits for all 256 arrivals before issuing P@V_lo.
Returns:
MBarType
po_hi_mbar
po_hi_mbar(self) -> MBarType
Raw pointer to PO_hi barrier.
Correction arrives here (128 threads) via consumer_o_hi().release(). MMA waits for all 128 arrivals before issuing P@V_hi.
Returns:
MBarType
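The two PO barriers differ only in who arrives. The minimal Python model below assumes that each barrier's expected arrival count equals the documented number of arrivals (256 for PO_lo, 128 for PO_hi); `CountingBarrier`, `po_lo_ready`, and `po_hi_ready` are hypothetical names.

```python
class CountingBarrier:
    """Toy barrier: a phase completes once `expected` arrivals have been made."""

    def __init__(self, expected: int):
        self.expected = expected
        self.arrived = 0
        self.phase = 0

    def arrive(self) -> None:
        self.arrived += 1
        if self.arrived == self.expected:  # phase complete: flip parity, reset
            self.phase ^= 1
            self.arrived = 0

    def phase_done(self, parity: int) -> bool:
        return self.phase != parity


SOFTMAX_THREADS = 128
CORRECTION_THREADS = 128


def po_lo_ready(softmax_arrivals: int, correction_arrivals: int) -> bool:
    po_lo = CountingBarrier(SOFTMAX_THREADS + CORRECTION_THREADS)  # 256 arrivals
    for _ in range(softmax_arrivals):     # one per softmax thread, after writing P
        po_lo.arrive()
    for _ in range(correction_arrivals):  # consumer_o_lo().release(), one per thread
        po_lo.arrive()
    return po_lo.phase_done(0)


def po_hi_ready(correction_arrivals: int) -> bool:
    po_hi = CountingBarrier(CORRECTION_THREADS)  # 128 arrivals
    for _ in range(correction_arrivals):  # consumer_o_hi().release(), one per thread
        po_hi.arrive()
    return po_hi.phase_done(0)


print(po_lo_ready(128, 0), po_lo_ready(128, 128), po_hi_ready(128))  # False True True
```

Folding both groups into PO_lo lets a single MMA wait cover both preconditions of P@V_lo (P in SMEM and O_lo rescaled), whereas P@V_hi only needs the previous O_hi rescaled.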
get_kv_mbars
get_kv_mbars(self) -> MBarType
Base pointer for KV pipeline barriers.
Layout: 2 × num_kv_stages barriers in {producer, consumer} pairs. Used by load warp (producer) and MMA warp (consumer).
Returns:
MBarType
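A multi-stage KV ring is usually indexed by taking the iteration modulo the stage count, with the wait parity flipping on every full lap. The sketch below applies that common convention to the 2 × num_kv_stages barriers behind get_kv_mbars(). It assumes the {producer, consumer} pairs are interleaved as [p0, c0, p1, c1, ...], which is a literal reading of the layout note rather than a confirmed detail; `kv_stage_barriers` is hypothetical.

```python
def kv_stage_barriers(iteration: int, num_kv_stages: int):
    """Hypothetical indexing into the KV barrier block returned by get_kv_mbars().

    Assumes interleaved pairs [p0, c0, p1, c1, ...] and round-robin stages.
    Returns (stage, phase parity, producer index, consumer index), with
    indices relative to the KV base pointer.
    """
    stage = iteration % num_kv_stages
    phase = (iteration // num_kv_stages) & 1  # parity flips after each full lap
    producer = 2 * stage                      # producer barrier (load warp side)
    consumer = 2 * stage + 1                  # consumer barrier (MMA warp side)
    return stage, phase, producer, consumer


for it in range(7):
    print(it, kv_stage_barriers(it, num_kv_stages=3))
```

If the barriers are instead stored as two blocks (all producers, then all consumers), only the two index lines change: the producer index becomes `stage` and the consumer index becomes `num_kv_stages + stage`.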
num_mbars