Mojo struct
Depth512MBars
struct Depth512MBars[num_kv_stages: Int]
Manages all mbarrier resources for depth=512 pair-CTA attention.
Parameters
- num_kv_stages (Int): Number of fused KV pipeline buffer slots.
Fields
- mbar_base (MBarType): Base pointer to the mbarrier array in shared memory, as passed to __init__.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable
comptime members
C_consumer_offset
comptime C_consumer_offset = 5
C_producer_offset
comptime C_producer_offset = 4
KV_barriers
comptime KV_barriers = (2 * num_kv_stages)
KV_offset
comptime KV_offset = 10
num_high_count_barriers
comptime num_high_count_barriers = 6
O_mma_hi_offset
comptime O_mma_hi_offset = 9
O_mma_lo_offset
comptime O_mma_lo_offset = 8
PO_hi_offset
comptime PO_hi_offset = 1
PO_lo_offset
comptime PO_lo_offset = 0
S_even_consumer_offset
comptime S_even_consumer_offset = 2
S_even_producer_offset
comptime S_even_producer_offset = 6
S_odd_consumer_offset
comptime S_odd_consumer_offset = 3
S_odd_producer_offset
comptime S_odd_producer_offset = 7
size
comptime size = (10 + Depth512MBars[num_kv_stages].KV_barriers)
SPipelineConsumer
comptime SPipelineConsumer = RolePipeline[1, False]
SPipelineProducer
comptime SPipelineProducer = RolePipeline[1, cta_group=2]
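The offsets above describe a flat array of barriers: ten fixed slots followed by the KV barriers. As a rough illustration, the layout can be modeled on the host in plain Python (not Mojo, so that it runs without the kernel library). The dictionary and helper names below are hypothetical; only the numeric values come from the comptime members listed above.

```python
# Fixed barrier slots, copied from the comptime members listed above.
FIXED_OFFSETS = {
    "PO_lo": 0,
    "PO_hi": 1,
    "S_even_consumer": 2,
    "S_odd_consumer": 3,
    "C_producer": 4,
    "C_consumer": 5,
    "S_even_producer": 6,
    "S_odd_producer": 7,
    "O_mma_lo": 8,
    "O_mma_hi": 9,
}
KV_OFFSET = 10


def kv_barriers(num_kv_stages: int) -> int:
    """Mirror of KV_barriers = 2 * num_kv_stages."""
    return 2 * num_kv_stages


def size(num_kv_stages: int) -> int:
    """Mirror of size = 10 + KV_barriers."""
    return 10 + kv_barriers(num_kv_stages)


def check_layout(num_kv_stages: int) -> None:
    """Check that the fixed slots are dense and the KV block ends at size."""
    offsets = sorted(FIXED_OFFSETS.values())
    # The ten fixed slots occupy indices 0..9 with no gaps or duplicates.
    assert offsets == list(range(KV_OFFSET))
    # The KV barriers start at KV_OFFSET and fill the rest of the array.
    assert KV_OFFSET + kv_barriers(num_kv_stages) == size(num_kv_stages)


if __name__ == "__main__":
    check_layout(4)
    print(size(4))  # 18
```

The check confirms that the listed offsets leave no unused slot below KV_offset, so size is exactly the number of barriers addressed through mbar_base.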
Methods
__init__
__init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
Wrap a base mbarrier pointer from Depth512AttentionSMem.mbar_base().
init
init(self, *, lane_idx: Int32)
Initialize all barriers. Call once per warp lane (lane_idx = tid % 32).
With num_kv_stages <= 11, size = 10 + 2 * num_kv_stages <= 32 = WARP_SIZE, so a single wave of 32 lanes covers all barriers.
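The bound on num_kv_stages follows directly from the size formula. The Python sketch below works through that arithmetic. It assumes, for illustration only, that lane i initializes barrier i and that lanes at or beyond size do nothing; the source does not spell out the per-lane mapping, and the helper names are hypothetical.

```python
WARP_SIZE = 32
FIXED_BARRIERS = 10  # slots 0..9, before KV_offset


def size(num_kv_stages: int) -> int:
    """Total barrier count: 10 fixed slots plus 2 per KV stage."""
    return FIXED_BARRIERS + 2 * num_kv_stages


def max_stages_single_wave(warp_size: int = WARP_SIZE) -> int:
    """Largest num_kv_stages for which size still fits in one warp."""
    return (warp_size - FIXED_BARRIERS) // 2


def lanes_with_work(num_kv_stages: int) -> list[int]:
    # Assumption: lane i initializes barrier i; higher lanes are idle.
    return [lane for lane in range(WARP_SIZE) if lane < size(num_kv_stages)]


if __name__ == "__main__":
    print(max_stages_single_wave())   # 11
    print(len(lanes_with_work(4)))    # 18
```

With 12 stages the model needs 34 barriers, which no longer fits in one wave of 32 lanes.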
producer_s_even
producer_s_even(self) -> Depth512MBars[num_kv_stages].SPipelineProducer
MMA warp: acquire waits on S_even consumer (buffer free), commit arrives at S_even producer (S written to TMEM).
Returns:
Depth512MBars[num_kv_stages].SPipelineProducer
producer_s_odd
producer_s_odd(self) -> Depth512MBars[num_kv_stages].SPipelineProducer
MMA warp: acquire waits on S_odd consumer (buffer free), commit arrives at S_odd producer (S written to TMEM).
Returns:
Depth512MBars[num_kv_stages].SPipelineProducer
consumer_s_even
consumer_s_even(self) -> Depth512MBars[num_kv_stages].SPipelineConsumer
Softmax warp: wait on S_even producer (S ready in TMEM), release arrives at S_even consumer (S loaded, buffer free).
Returns:
Depth512MBars[num_kv_stages].SPipelineConsumer
consumer_s_odd
consumer_s_odd(self) -> Depth512MBars[num_kv_stages].SPipelineConsumer
Softmax warp: wait on S_odd producer (S ready in TMEM), release arrives at S_odd consumer (S loaded, buffer free).
Returns:
Depth512MBars[num_kv_stages].SPipelineConsumer
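The four S accessors describe a producer/consumer handshake over a pair of barriers: the MMA warp acquires and commits, the softmax warp waits and releases. The toy Python model below walks through that ordering for one S buffer. It is a sequential simulation, not the RolePipeline implementation, and it assumes a single buffer slot that starts out free. PhaseBarrier and simulate are hypothetical names.

```python
from itertools import cycle


class PhaseBarrier:
    """Counts completed phases; stands in for an mbarrier with one arrival."""

    def __init__(self) -> None:
        self.completed = 0

    def arrive(self) -> None:
        self.completed += 1


def simulate(iterations: int) -> list[str]:
    producer_bar = PhaseBarrier()  # signals "S written to TMEM"
    consumer_bar = PhaseBarrier()  # signals "S loaded, buffer free"
    events: list[str] = []
    produced = consumed = 0
    # The producer gets two turns per consumer turn, so it tries to run ahead.
    for actor in cycle("PPC"):
        if consumed == iterations:
            break
        if actor == "P":
            # acquire: blocked until the previous S has been released.
            # Assumption: the buffer starts free, so iteration 0 passes.
            if produced < iterations and consumer_bar.completed >= produced:
                events.append(f"commit {produced}")
                producer_bar.arrive()  # commit
                produced += 1
        else:
            # wait: blocked until S for this iteration has been committed.
            if producer_bar.completed > consumed:
                events.append(f"release {consumed}")
                consumer_bar.arrive()  # release
                consumed += 1
    return events


if __name__ == "__main__":
    print(simulate(2))
    # ['commit 0', 'release 0', 'commit 1', 'release 1']
```

Even though the producer is scheduled twice as often, its second attempt is rejected until the consumer releases, so commits and releases strictly alternate in this model.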
producer_c
producer_c(self) -> RolePipeline[1]
Softmax warp: acquire waits on C consumer (buffer free), commit arrives at C producer (correction written to SMEM).
Returns:
RolePipeline
consumer_c
consumer_c(self) -> RolePipeline[1, False]
Correction warp: wait on C producer (correction ready), release arrives at C consumer (correction consumed).
Returns:
RolePipeline
producer_o_lo
producer_o_lo(self) -> RolePipeline[1, cta_group=2]
MMA warp P@V_lo pipeline. Acquire: wait on PO_lo (P ready + previous O_lo rescaled). Commit: mma_arrive at O_mma_lo (V_lo accumulation done).
Returns:
RolePipeline
producer_o_hi
producer_o_hi(self) -> RolePipeline[1, cta_group=2]
MMA warp P@V_hi pipeline. Acquire: wait on PO_hi (previous O_hi rescaled). Commit: mma_arrive at O_mma_hi (V_hi accumulation done).
Returns:
RolePipeline
consumer_o_lo
consumer_o_lo(self) -> RolePipeline[1, False]
Correction warp O_lo pipeline. Wait: wait on O_mma_lo (V_lo accumulation done, O_lo safe to read). Release: arrive at PO_lo (O_lo rescaling done, 128 threads).
Returns:
RolePipeline
consumer_o_hi
consumer_o_hi(self) -> RolePipeline[1, False]
Correction warp O_hi pipeline. Wait: wait on O_mma_hi (V_hi accumulation done, O_hi safe to read). Release: arrive at PO_hi (O_hi rescaling done, 128 threads).
Returns:
RolePipeline
po_lo_mbar
po_lo_mbar(self) -> MBarType
Raw pointer to PO_lo barrier.
Softmax arrives here (128 threads) after writing P to SMEM. Correction also arrives (128 threads) via consumer_o_lo().release(). MMA waits for all 256 arrivals before P@V_lo.
Returns:
MBarType
po_hi_mbar
po_hi_mbar(self) -> MBarType
Raw pointer to PO_hi barrier.
Correction arrives here (128 threads) via consumer_o_hi().release(). MMA waits for all 128 arrivals before P@V_hi.
Returns:
MBarType
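The arrival counts for PO_lo and PO_hi differ because two warp groups feed PO_lo and only one feeds PO_hi. The Python sketch below models an arrival-counting barrier to make the 256 and 128 totals concrete. CountingBarrier is a hypothetical stand-in, not the SharedMemBarrier API, and the phase bit is a simplification of real mbarrier phase tracking.

```python
SOFTMAX_THREADS = 128
CORRECTION_THREADS = 128


class CountingBarrier:
    """Flips its phase bit once the expected number of arrivals is reached."""

    def __init__(self, expected: int) -> None:
        self.expected = expected
        self.pending = expected
        self.phase = 0

    def arrive(self, count: int = 1) -> None:
        # This model rejects over-arrival within a phase.
        assert 0 < count <= self.pending
        self.pending -= count
        if self.pending == 0:
            self.phase ^= 1              # waiters on the old phase may proceed
            self.pending = self.expected  # re-arm for the next phase


if __name__ == "__main__":
    # PO_lo: softmax (P written) plus correction (O_lo rescaled).
    po_lo = CountingBarrier(SOFTMAX_THREADS + CORRECTION_THREADS)
    po_lo.arrive(SOFTMAX_THREADS)
    print(po_lo.phase)  # 0, still waiting on correction
    po_lo.arrive(CORRECTION_THREADS)
    print(po_lo.phase)  # 1, MMA may start P@V_lo

    # PO_hi: correction only (O_hi rescaled).
    po_hi = CountingBarrier(CORRECTION_THREADS)
    po_hi.arrive(CORRECTION_THREADS)
    print(po_hi.phase)  # 1, MMA may start P@V_hi
```

In this model the MMA warp cannot pass PO_lo on softmax arrivals alone, which matches the description that P@V_lo needs both P and the rescaled O_lo.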
get_kv_mbars
get_kv_mbars(self) -> MBarType
Base pointer for KV pipeline barriers.
Layout: 2 × num_kv_stages barriers in {producer, consumer} pairs. Used by load warp (producer) and MMA warp (consumer).
Returns:
MBarType
num_mbars