
Mojo struct

FA4MiscMBars

struct FA4MiscMBars[*, num_qk_stages: Int = 1, num_pv_stages: Int = 1, num_kv_stages: Int = 2, use_order_barriers: Bool = True, use_fused_kv: Bool = False, pair_cta: Bool = False, num_qo: Int = 2]

Manages all mbarrier resources for FA4.

This struct consolidates all mbarrier management including:

  • S barriers (score MMA synchronization)
  • C barriers (correction synchronization)
  • Order barriers (softmax ordering)
  • Q1Sync barriers (Q tile synchronization)
  • K/V pipeline barriers (separate K and V)
  • O pipeline barriers

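Memory layout, in index order (barriers initialized with count=128 first, then those with count=1):

  • count=128: [S0_cons] [S1_cons] [C0] [C1] [Order*]
  • count=1: [S0_prod] [S1_prod] [Q1Sync**] [K] [V***] [O_prod]

*Order barriers are present only when use_order_barriers=True.
**Q1Sync barriers are present only when num_qo == 2.
***V barriers are present only when use_fused_kv=False; in fused mode, V shares the K barriers.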
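The layout above can be modeled with a short Python sketch that assigns each region its offset by accumulating barrier counts in order. This is a model of the documented offset arithmetic, not the Mojo API: the function name `fa4_misc_mbar_layout` and the region labels are hypothetical, and `pair_cta` is left out because none of the documented offsets depend on it.

```python
# Hypothetical Python model of the FA4MiscMBars barrier layout (not the Mojo API).
# Region sizes mirror the comptime members documented on this page.


def fa4_misc_mbar_layout(num_qk_stages=1, num_pv_stages=1, num_kv_stages=2,
                         use_order_barriers=True, use_fused_kv=False,
                         num_qo=2):
    """Return ({region: (offset, count)}, total_size) in layout order."""
    regions = [
        # Barriers initialized with count=128.
        ("S0_cons", num_pv_stages),
        ("S1_cons", num_pv_stages),
        ("C0", 2),
        ("C1", 2),
        ("Order", 2 if use_order_barriers else 0),
        # Barriers initialized with count=1.
        ("S0_prod", 1),
        ("S1_prod", 1),
        ("Q1Sync", num_qk_stages if num_qo == 2 else 0),
        ("K", 2 * num_qk_stages * num_kv_stages),
        ("V", 0 if use_fused_kv else 2 * num_kv_stages),
        ("O_prod", 2),
    ]
    layout = {}
    offset = 0
    for name, count in regions:
        layout[name] = (offset, count)  # region starts where the previous one ended
        offset += count
    return layout, offset


if __name__ == "__main__":
    layout, size = fa4_misc_mbar_layout()
    for name, (start, count) in layout.items():
        print(f"{name:8s} offset={start:2d} count={count}")
    print("size =", size)
```

With the default parameters this gives size = 21, and the count=128 group occupies indices 0 through 7, which matches number_warpgroup_count = S0_producer_offset = 8.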

Parameters

  • num_qk_stages (Int): Number of stages for the Q@K' MMA (K loading can be staged).
  • num_pv_stages (Int): Number of stages for the P@V MMA (P writing can be staged).
  • num_kv_stages (Int): Number of KV buffer stages for double/triple buffering.
  • use_order_barriers (Bool): When True, allocate order barriers to prevent softmax warp group overlap. When False, order barriers are omitted.
  • use_fused_kv (Bool): When True, K and V share a single pipeline; when False, they use separate pipelines.
  • pair_cta (Bool): When True, use the 2-CTA (paired) implementation; when False, use the 1-CTA implementation.
  • num_qo (Int): Number of Q tiles per CTA. When 1, the Q1Sync slot is collapsed and K_offset shifts down by num_qk_stages. Must be 2 for any caller of q1_wait_mbar().

Fields

  • mbar_base (MBarType): Pointer to the first barrier of the shared-memory mbarrier array managed by this struct.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

Expressions that refer to other members use Self, i.e. FA4MiscMBars with the same parameter values.

C0_offset

comptime C0_offset = (2 * num_pv_stages)

C1_offset

comptime C1_offset = (Self.C0_offset + 2)

K_barriers

comptime K_barriers = ((2 * num_qk_stages) * num_kv_stages)

K_offset

comptime K_offset = (Self.Q1SyncIdx + Self.Q1Sync_count)

num_order_barriers

comptime num_order_barriers = 2 if use_order_barriers else 0

number_warpgroup_count

comptime number_warpgroup_count = Self.S0_producer_offset

Number of barriers initialized with count=128: every barrier before S0_producer_offset.

O_producer_offset

comptime O_producer_offset = (Self.V_offset + Self.V_barriers)

order_offset

comptime order_offset = (Self.C1_offset + 2)

Q1Sync_count

comptime Q1Sync_count = num_qk_stages if (num_qo == 2) else 0

Q1SyncIdx

comptime Q1SyncIdx = (Self.S1_producer_offset + 1)

S0_consumer_offset

comptime S0_consumer_offset = 0

S0_producer_offset

comptime S0_producer_offset = (Self.order_offset + Self.num_order_barriers)

S1_consumer_offset

comptime S1_consumer_offset = num_pv_stages

S1_producer_offset

comptime S1_producer_offset = (Self.S0_producer_offset + 1)

size

comptime size = (Self.O_producer_offset + 2)

SPipelineConsumer

comptime SPipelineConsumer = RolePipeline[1, False, consumer_sub_stages=num_pv_stages]

SPipelineProducer

comptime SPipelineProducer = RolePipeline[1, consumer_sub_stages=num_pv_stages]

V_barriers

comptime V_barriers = 0 if use_fused_kv else (2 * num_kv_stages)

V_offset

comptime V_offset = (Self.K_offset + Self.K_barriers)
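Because each offset is defined in terms of the previous one, the whole chain can be evaluated by hand. The Python sketch below transcribes the comptime expressions one-for-one (Python stands in for Mojo's compile-time evaluation; the function name `fa4_comptime_members` is hypothetical) and checks two documented effects: num_qo=1 moves K_offset down by num_qk_stages, and use_fused_kv=True drops the V barriers.

```python
# Hypothetical transcription of the FA4MiscMBars comptime members into Python.


def fa4_comptime_members(num_qk_stages=1, num_pv_stages=1, num_kv_stages=2,
                         use_order_barriers=True, use_fused_kv=False,
                         num_qo=2):
    C0_offset = 2 * num_pv_stages
    C1_offset = C0_offset + 2
    order_offset = C1_offset + 2
    num_order_barriers = 2 if use_order_barriers else 0
    S0_producer_offset = order_offset + num_order_barriers
    S1_producer_offset = S0_producer_offset + 1
    Q1SyncIdx = S1_producer_offset + 1
    Q1Sync_count = num_qk_stages if num_qo == 2 else 0
    K_offset = Q1SyncIdx + Q1Sync_count
    K_barriers = 2 * num_qk_stages * num_kv_stages
    V_offset = K_offset + K_barriers
    V_barriers = 0 if use_fused_kv else 2 * num_kv_stages
    O_producer_offset = V_offset + V_barriers
    size = O_producer_offset + 2
    return {
        "number_warpgroup_count": S0_producer_offset,
        "Q1SyncIdx": Q1SyncIdx,
        "K_offset": K_offset,
        "V_offset": V_offset,
        "O_producer_offset": O_producer_offset,
        "size": size,
    }


if __name__ == "__main__":
    two_q = fa4_comptime_members(num_qk_stages=2, num_qo=2)
    one_q = fa4_comptime_members(num_qk_stages=2, num_qo=1)
    # K_offset shifts down by num_qk_stages when the Q1Sync slot collapses.
    print(two_q["K_offset"], one_q["K_offset"])
    # Fused K/V: no V barriers, so O_producer_offset == V_offset.
    fused = fa4_comptime_members(use_fused_kv=True)
    print(fused["V_offset"], fused["O_producer_offset"], fused["size"])
```

For num_qk_stages=2 this prints 12 and 10 for K_offset, and the fused configuration prints 15 15 17.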

Methods

__init__

def __init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

init

def init(self, *, lane_idx: Int32)
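This page does not say what init does with lane_idx. The memory-layout note does state which barriers get which arrival count, so the sketch below models only that part: indices below number_warpgroup_count expect 128 arrivals (one warp group of 4 warps × 32 threads) and the rest expect 1. The function `expected_arrival_counts` is a hypothetical Python model, and it assumes pair_cta does not change these counts, which this page does not confirm.

```python
# Hypothetical model of the arrival counts assigned at init time, per the layout note.

WARPGROUP_THREADS = 128  # one warp group = 4 warps x 32 threads


def expected_arrival_counts(number_warpgroup_count, size):
    """Arrival count for every barrier index in [0, size)."""
    return [WARPGROUP_THREADS if i < number_warpgroup_count else 1
            for i in range(size)]


if __name__ == "__main__":
    # Default parameters: number_warpgroup_count = 8, size = 21.
    counts = expected_arrival_counts(8, 21)
    print(counts.count(128), counts.count(1))
```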

producer_s0

def producer_s0(self) -> Self.SPipelineProducer

Get S producer for warp group 0.

Returns:

Self.SPipelineProducer

producer_s1

def producer_s1(self) -> Self.SPipelineProducer

Get S producer for warp group 1.

Returns:

Self.SPipelineProducer

consumer_s

def consumer_s(self, wg_idx: UInt32) -> Self.SPipelineConsumer

Get S consumer for the given warp group.

Returns:

Self.SPipelineConsumer
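The S barrier offsets determine which indices each warp group uses. The Python sketch below (hypothetical function name `s_barrier_indices`) returns the producer index and consumer indices for wg_idx 0 or 1. It assumes each S consumer spans num_pv_stages consecutive barriers; that is inferred from S1_consumer_offset = num_pv_stages and consumer_sub_stages=num_pv_stages rather than stated on this page.

```python
# Hypothetical model of the S barrier indices behind producer_s0/producer_s1/consumer_s.


def s_barrier_indices(wg_idx, num_pv_stages, use_order_barriers=True):
    """Return (producer_index, consumer_indices) for warp group wg_idx."""
    if wg_idx not in (0, 1):
        raise ValueError("wg_idx must be 0 or 1")
    num_order_barriers = 2 if use_order_barriers else 0
    # order_offset = C1_offset + 2 = 2 * num_pv_stages + 4
    s0_producer_offset = 2 * num_pv_stages + 4 + num_order_barriers
    producer = s0_producer_offset + wg_idx  # S1_producer_offset = S0_producer_offset + 1
    consumer_base = wg_idx * num_pv_stages  # S0_consumer_offset = 0, S1_consumer_offset = num_pv_stages
    consumers = list(range(consumer_base, consumer_base + num_pv_stages))
    return producer, consumers


if __name__ == "__main__":
    for wg in (0, 1):
        print(wg, s_barrier_indices(wg, num_pv_stages=2))
```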

consumer_c0

def consumer_c0(self) -> RolePipeline[1, False]

Returns:

RolePipeline[1, False]

consumer_c1

def consumer_c1(self) -> RolePipeline[1, False]

Returns:

RolePipeline[1, False]

producer_c

def producer_c(self, wg_idx: UInt32) -> RolePipeline[1]

Returns:

RolePipeline[1]
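The C barriers occupy two slots each, starting at C0_offset and C1_offset. A minimal Python sketch of the resulting indices follows. It assumes consumer_c0 and producer_c(0) use the C0 slots while consumer_c1 and producer_c(1) use the C1 slots, a mapping the method names suggest but this page does not state; the function name `c_barrier_indices` is hypothetical.

```python
# Hypothetical model of the C (correction) barrier indices.


def c_barrier_indices(wg_idx, num_pv_stages):
    """Indices of the two C barriers for warp group wg_idx (assumed: wg 0 -> C0, wg 1 -> C1)."""
    if wg_idx not in (0, 1):
        raise ValueError("wg_idx must be 0 or 1")
    c0_offset = 2 * num_pv_stages  # C0_offset: right after both S consumer regions
    c1_offset = c0_offset + 2      # C1_offset = C0_offset + 2
    base = c1_offset if wg_idx == 1 else c0_offset
    return [base, base + 1]


if __name__ == "__main__":
    print(c_barrier_indices(0, 1), c_barrier_indices(1, 1))
```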

pipeline_order_wait

def pipeline_order_wait(self, wg_idx: UInt32) -> MBarType

Returns:

MBarType

pipeline_order_arrive

def pipeline_order_arrive(self, wg_idx: UInt32) -> MBarType

Returns:

MBarType

q1_wait_mbar

def q1_wait_mbar(self) -> MBarType

Requires num_qo == 2 (no Q1Sync barrier is allocated when num_qo == 1).

Returns:

MBarType
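The num_qo requirement on q1_wait_mbar() follows from Q1Sync_count: with num_qo == 1 there are no Q1Sync barriers, so Q1SyncIdx equals K_offset and would alias the first K barrier. The Python sketch below (hypothetical name `q1_sync_index`) models the returned barrier as index Q1SyncIdx and treats the precondition as a runtime error; how the Mojo code enforces it is not shown on this page.

```python
# Hypothetical model of the barrier index behind q1_wait_mbar().


def q1_sync_index(num_pv_stages=1, use_order_barriers=True, num_qo=2):
    if num_qo != 2:
        # Q1Sync_count == 0, so Q1SyncIdx == K_offset: the slot would alias a K barrier.
        raise ValueError("q1_wait_mbar() requires num_qo == 2")
    num_order_barriers = 2 if use_order_barriers else 0
    s1_producer_offset = 2 * num_pv_stages + 4 + num_order_barriers + 1
    return s1_producer_offset + 1  # Q1SyncIdx


if __name__ == "__main__":
    print(q1_sync_index())  # default parameters
    try:
        q1_sync_index(num_qo=1)
    except ValueError as err:
        print("rejected:", err)
```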

get_k_mbars

def get_k_mbars(self) -> MBarType

Returns the base pointer for the K pipeline barriers.

Returns:

MBarType

get_v_mbars

def get_v_mbars(self) -> MBarType

Returns the base pointer for the V pipeline barriers. In fused mode (use_fused_kv=True), returns the same pointer as get_k_mbars (shared pipeline).

Returns:

MBarType
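The difference between the two getters comes down to V_offset. A minimal Python model follows; the function name `kv_barrier_bases` is hypothetical, and indices stand in for the returned pointers, which are assumed to be mbar_base offset by K_offset and V_offset respectively.

```python
# Hypothetical model of get_k_mbars()/get_v_mbars() as barrier indices.


def kv_barrier_bases(k_offset, num_qk_stages, num_kv_stages, use_fused_kv):
    """Return (k_base, v_base) as indices into the barrier array."""
    k_barriers = 2 * num_qk_stages * num_kv_stages
    if use_fused_kv:
        # Shared pipeline: V uses the same barriers as K.
        return k_offset, k_offset
    # Separate pipelines: V_offset = K_offset + K_barriers.
    return k_offset, k_offset + k_barriers


if __name__ == "__main__":
    # K_offset is 11 for the default parameters.
    print(kv_barrier_bases(11, 1, 2, use_fused_kv=False))
    print(kv_barrier_bases(11, 1, 2, use_fused_kv=True))
```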

combined_p_o_consumer

def combined_p_o_consumer(self, wg_idx: UInt32) -> MBarType

Combined P+O consumer barrier for the given warp group.

Arrived on by both softmax (P ready) and correction (O rescaled). Returns S_consumer[0] of warp group wg_idx, where wg_idx is 0 or 1.

Returns:

MBarType

consumer_o

def consumer_o(self) -> RolePipeline[2, False, consumer_sub_stages=num_pv_stages]

Get the O consumer pipeline.

Wait side: O_producer barriers (stride 1, indexed by stage). Release side: combined S+O barriers (S_consumer[0] per warp group, stride num_pv_stages).

Returns:

RolePipeline[2, False, consumer_sub_stages=num_pv_stages]
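The wait and release sides of consumer_o index different regions. The Python sketch below (hypothetical function name `consumer_o_indices`) assumes that O pipeline stage s belongs to warp group s, which matches the two-stage RolePipeline[2, ...] and the per-warp-group release barriers described above but is not stated outright. Under that assumption, the release index is the same barrier that combined_p_o_consumer(s) returns.

```python
# Hypothetical model of the barrier indices behind consumer_o().


def consumer_o_indices(stage, o_producer_offset, num_pv_stages):
    """Return (wait_index, release_index) for O pipeline stage 0 or 1."""
    if stage not in (0, 1):
        raise ValueError("stage must be 0 or 1")
    wait_index = o_producer_offset + stage  # O_producer barriers, stride 1
    release_index = stage * num_pv_stages   # S_consumer[0] of warp group `stage`
    return wait_index, release_index


if __name__ == "__main__":
    # Default parameters: O_producer_offset = 19, num_pv_stages = 1.
    print([consumer_o_indices(s, 19, 1) for s in (0, 1)])
```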

producer_o0

def producer_o0(self) -> RolePipeline[1]

Get O producer for warp group 0.

Returns:

RolePipeline[1]

producer_o1

def producer_o1(self) -> RolePipeline[1]

Get O producer for warp group 1.

Returns:

RolePipeline[1]

num_mbars

static def num_mbars() -> UInt32

Returns:

UInt32
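The total barrier count can also be written as a closed-form sum of the region sizes. The sketch below assumes num_mbars() returns the size member (this page does not say so) and that each SharedMemBarrier occupies 8 bytes, the size of a PTX mbarrier object; the Python function and constant names are hypothetical.

```python
# Hypothetical closed-form model of num_mbars() and the shared memory it implies.

MBARRIER_BYTES = 8  # assumed: one 64-bit mbarrier object per SharedMemBarrier


def num_mbars(num_qk_stages=1, num_pv_stages=1, num_kv_stages=2,
              use_order_barriers=True, use_fused_kv=False, num_qo=2):
    return (
        2 * num_pv_stages                             # S0/S1 consumer barriers
        + 2 + 2                                       # C0, C1
        + (2 if use_order_barriers else 0)            # order barriers
        + 2                                           # S0/S1 producer barriers
        + (num_qk_stages if num_qo == 2 else 0)       # Q1Sync barriers
        + 2 * num_qk_stages * num_kv_stages           # K pipeline
        + (0 if use_fused_kv else 2 * num_kv_stages)  # V pipeline
        + 2                                           # O producer barriers
    )


if __name__ == "__main__":
    for cfg in ({}, {"num_qk_stages": 2, "num_pv_stages": 2, "num_kv_stages": 3}):
        n = num_mbars(**cfg)
        print(cfg, n, "barriers,", n * MBARRIER_BYTES, "bytes")
```

The defaults give 21 barriers (168 bytes); the larger configuration gives 34 barriers (272 bytes).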