For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
FA4MiscMBars
struct FA4MiscMBars[*, num_qk_stages: Int = Int(1), num_pv_stages: Int = Int(1), num_kv_stages: Int = Int(2), use_order_barriers: Bool = True, use_fused_kv: Bool = False, pair_cta: Bool = False, num_qo: Int = Int(2)]
Manages all mbarrier resources for FA4.
This struct consolidates all mbarrier management including:
- S barriers (score MMA synchronization)
- C barriers (correction synchronization)
- Order barriers (softmax ordering)
- Q1Sync barriers (Q tile synchronization)
- K/V pipeline barriers (separate K and V)
- O pipeline barriers
Memory layout (barriers with count=128 first, then barriers with count=1):
`[S0_cons] [S1_cons] [C0] [C1] [Order*] | [S0_prod] [S1_prod] [Q1Sync**] [K] [V***] [O_prod]`
\* Order barriers are only present when `use_order_barriers=True`.
\*\* Q1Sync barriers are only present when `num_qo == 2`.
\*\*\* V barriers are only present when `use_fused_kv=False`.
Parameters
- `num_qk_stages` (`Int`): Number of stages for the Q@K' MMA (K loading can be staged).
- `num_pv_stages` (`Int`): Number of stages for the P@V MMA (P writing can be staged).
- `num_kv_stages` (`Int`): Number of KV buffer stages for double/triple buffering.
- `use_order_barriers` (`Bool`): When True, allocate order barriers to prevent softmax warp-group overlap. When False, order barriers are omitted.
- `use_fused_kv` (`Bool`): When True, K and V share a single pipeline and no separate V barriers are allocated. When False, K and V use separate pipelines.
- `pair_cta` (`Bool`): Selects the 2-CTA (paired) implementation when True and the 1-CTA implementation when False.
- `num_qo` (`Int`): Number of Q tiles per CTA. When 1, the `Q1Sync` slot is collapsed, so `K_offset` and every later offset shift down by `num_qk_stages`. Must be 2 for any caller of `q1_wait_mbar()`.
Fields
- `mbar_base` (`MBarType`)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
C0_offset
comptime C0_offset = (Int(2) * num_pv_stages)
C1_offset
comptime C1_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=use_fused_kv, pair_cta=pair_cta, num_qo=num_qo].C0_offset + Int(2))
K_barriers
comptime K_barriers = ((Int(2) * num_qk_stages) * num_kv_stages)
K_offset
comptime K_offset = ((Int(2) * num_pv_stages) + (Int(2) if use_order_barriers else Int(0)) + Int(6) + (num_qk_stages if (num_qo == Int(2)) else Int(0)))
Equivalently, `Q1SyncIdx + Q1Sync_count`: the K barriers start immediately after the Q1Sync slot.
num_order_barriers
comptime num_order_barriers = Int(2) if use_order_barriers else Int(0)
number_warpgroup_count
comptime number_warpgroup_count = FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=use_fused_kv, pair_cta=pair_cta, num_qo=num_qo].S0_producer_offset
The number of barriers in the count=128 region of the layout, i.e. every barrier before `S0_producer_offset`.
O_producer_offset
comptime O_producer_offset = ((Int(2) * num_kv_stages * num_qk_stages) + (Int(2) * num_pv_stages) + (num_qk_stages if (num_qo == Int(2)) else Int(0)) + (Int(2) if use_order_barriers else Int(0)) + Int(6) + (Int(0) if use_fused_kv else (Int(2) * num_kv_stages)))
Equivalently, `V_offset + V_barriers`.
order_offset
comptime order_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=use_fused_kv, pair_cta=pair_cta, num_qo=num_qo].C1_offset + Int(2))
Q1Sync_count
comptime Q1Sync_count = num_qk_stages if (num_qo == Int(2)) else Int(0)
Q1SyncIdx
comptime Q1SyncIdx = ((Int(2) * num_pv_stages) + (Int(2) if use_order_barriers else Int(0)) + Int(6))
Equivalently, `S1_producer_offset + 1`.
S0_consumer_offset
comptime S0_consumer_offset = 0
S0_producer_offset
comptime S0_producer_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=use_fused_kv, pair_cta=pair_cta, num_qo=num_qo].order_offset + (Int(2) if use_order_barriers else Int(0)))
S1_consumer_offset
comptime S1_consumer_offset = num_pv_stages
S1_producer_offset
comptime S1_producer_offset = ((Int(2) * num_pv_stages) + (Int(2) if use_order_barriers else Int(0)) + Int(5))
Equivalently, `S0_producer_offset + 1`.
size
comptime size = ((Int(2) * num_kv_stages * num_qk_stages) + (Int(2) * num_pv_stages) + (num_qk_stages if (num_qo == Int(2)) else Int(0)) + (Int(0) if use_fused_kv else (Int(2) * num_kv_stages)) + (Int(2) if use_order_barriers else Int(0)) + Int(8))
Equivalently, `O_producer_offset + 2`. This is the total number of mbarriers in the layout.
SPipelineConsumer
comptime SPipelineConsumer = RolePipeline[Int(1), False, consumer_sub_stages=num_pv_stages]
SPipelineProducer
comptime SPipelineProducer = RolePipeline[Int(1), consumer_sub_stages=num_pv_stages]
V_barriers
comptime V_barriers = Int(0) if use_fused_kv else (Int(2) * num_kv_stages)
V_offset
comptime V_offset = ((Int(2) * num_pv_stages) + (num_qk_stages if (num_qo == Int(2)) else Int(0)) + (Int(2) if use_order_barriers else Int(0)) + Int(6) + (Int(2) * num_kv_stages * num_qk_stages))
Equivalently, `K_offset + K_barriers`.
Methods
__init__
def __init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
init
def init(self, *, lane_idx: Int32)
producer_s0
def producer_s0(self) -> Self.SPipelineProducer
Get the S producer for warp group 0.
Returns:
Self.SPipelineProducer
producer_s1
def producer_s1(self) -> Self.SPipelineProducer
Get the S producer for warp group 1.
Returns:
Self.SPipelineProducer
consumer_s
def consumer_s(self, wg_idx: UInt32) -> Self.SPipelineConsumer
Get the S consumer for the given warp group.
Returns:
Self.SPipelineConsumer
consumer_c0
def consumer_c0(self) -> RolePipeline[Int(1), False]
Returns:
RolePipeline[Int(1), False]
consumer_c1
def consumer_c1(self) -> RolePipeline[Int(1), False]
Returns:
RolePipeline[Int(1), False]
producer_c
def producer_c(self, wg_idx: UInt32) -> RolePipeline[Int(1)]
Returns:
RolePipeline[Int(1)]
pipeline_order_wait
def pipeline_order_wait(self, wg_idx: UInt32) -> MBarType
Returns:
MBarType
pipeline_order_arrive
def pipeline_order_arrive(self, wg_idx: UInt32) -> MBarType
Returns:
MBarType
q1_wait_mbar
def q1_wait_mbar(self) -> MBarType
Requires `num_qo == 2`; when `num_qo == 1` no Q1Sync barrier is allocated.
Returns:
MBarType
get_k_mbars
def get_k_mbars(self) -> MBarType
Returns the base pointer for the K pipeline barriers.
Returns:
MBarType
get_v_mbars
def get_v_mbars(self) -> MBarType
Returns the base pointer for the V pipeline barriers. In fused mode (`use_fused_kv=True`), returns the same pointer as `get_k_mbars` because K and V share one pipeline.
Returns:
MBarType
combined_p_o_consumer
def combined_p_o_consumer(self, wg_idx: UInt32) -> MBarType
Combined P+O consumer barrier for the given warp group.
Arrived at by BOTH softmax (P ready) and correction (O rescaled). Returns the first S-consumer barrier (`S_consumer[0]`) of warp group `wg_idx`, which must be 0 or 1.
Returns:
MBarType
consumer_o
def consumer_o(self) -> RolePipeline[Int(2), False, consumer_sub_stages=num_pv_stages]
Get the O consumer pipeline.
Wait side: `O_producer` barriers (stride 1, indexed by stage). Release side: combined S+O barriers (`S_consumer[0]` per warp group, stride `num_pv_stages`).
Returns:
RolePipeline[Int(2), False, consumer_sub_stages=num_pv_stages]
producer_o0
def producer_o0(self) -> RolePipeline[Int(1)]
Get the O producer for warp group 0.
Returns:
RolePipeline[Int(1)]
producer_o1
def producer_o1(self) -> RolePipeline[Int(1)]
Get the O producer for warp group 1.
Returns:
RolePipeline[Int(1)]
num_mbars
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!