Mojo struct
FA4MiscMBars
struct FA4MiscMBars[*, num_qk_stages: Int = 1, num_pv_stages: Int = 1, num_kv_stages: Int = 2, use_order_barriers: Bool = True, use_fused_kv: Bool = False, pair_cta: Bool = False]
Manages all mbarrier resources for FA4.
This struct consolidates all mbarrier management, including:
- S barriers (score MMA synchronization)
- C barriers (correction synchronization)
- Order barriers (softmax ordering)
- Q1Sync barriers (Q tile synchronization)
- K/V pipeline barriers (separate K and V pipelines unless use_fused_kv=True)
- O pipeline barriers
Memory layout (barriers with count=128 first, then barriers with count=1):
[S0_cons] [S1_cons] [C0] [C1] [Order*] | [S0_prod] [S1_prod] [Q1Sync] [K] [V] [O_prod]
*Order barriers are present only when use_order_barriers=True. The [V] region is empty when use_fused_kv=True.
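The count=128 region ends at number_warpgroup_count (defined below as S0_producer_offset), so the arrival count of a barrier depends only on its index. The Python sketch below models that split; it is not the Mojo init implementation. The helper name init_count is hypothetical, 128 is presumably one arrival per thread of a 128-thread warp group, and any effect pair_cta may have on the counts is ignored.

```python
def init_count(index, num_pv_stages=1, use_order_barriers=True):
    """Hypothetical arrival count for the mbarrier at position `index`."""
    # Slots before the '|' separator: S0_cons and S1_cons (num_pv_stages each),
    # C0 and C1 (2 each), and the optional pair of Order barriers.
    number_warpgroup_count = 2 * num_pv_stages + 4 + (2 if use_order_barriers else 0)
    # Warpgroup-wide barriers first, single-arrival barriers after.
    return 128 if index < number_warpgroup_count else 1


print([init_count(i) for i in range(10)])
```

With the defaults, indices 0 through 7 get count 128 and index 8 (S0_prod) is the first count=1 barrier.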
Parameters
- num_qk_stages (Int): Number of stages for the Q@K' MMA (K loading can be staged).
- num_pv_stages (Int): Number of stages for the P@V MMA (P writing can be staged).
- num_kv_stages (Int): Number of KV buffer stages, for double or triple buffering.
- use_order_barriers (Bool): When True, allocate order barriers to prevent the softmax warp groups from overlapping. When False, order barriers are omitted.
- use_fused_kv (Bool): Whether K and V share a single pipeline (True) or use separate pipelines (False).
- pair_cta (Bool): Whether to use the 1-CTA (False) or 2-CTA (True) implementation.
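Adding up the region sizes implied by the comptime members gives the total barrier count (size) as a function of these parameters. The Python function below is a sketch of that sum, useful for seeing how each parameter changes the shared-memory footprint; total_mbars is a hypothetical name, not part of the API. pair_cta does not appear in any of the offset formulas, so it is left out.

```python
def total_mbars(num_qk_stages=1, num_pv_stages=1, num_kv_stages=2,
                use_order_barriers=True, use_fused_kv=False):
    """Sketch: total number of mbarriers (the `size` comptime member)."""
    s_consumers = 2 * num_pv_stages                 # S0_cons + S1_cons
    c_barriers = 2 + 2                              # C0 + C1
    order = 2 if use_order_barriers else 0          # optional Order pair
    s_producers = 2                                 # S0_prod + S1_prod
    q1_sync = num_qk_stages                         # Q1Sync
    k = 2 * num_qk_stages * num_kv_stages           # K pipeline
    v = 0 if use_fused_kv else 2 * num_kv_stages    # V pipeline, absent when fused
    o_producers = 2                                 # O_prod
    return (s_consumers + c_barriers + order + s_producers
            + q1_sync + k + v + o_producers)


print(total_mbars())                                  # 21
print(total_mbars(use_fused_kv=True))                 # 17
print(total_mbars(use_order_barriers=False))          # 19
print(total_mbars(num_qk_stages=2, num_kv_stages=3))  # 32
```

With the defaults this gives 21 barriers; fusing K and V removes the 2 * num_kv_stages V barriers, and disabling order barriers removes 2.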
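Fields
- mbar_base (MBarType): Base pointer of the mbarrier array in shared memory; the offsets listed under comptime members are relative to it.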
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
In the definitions below, Self stands for FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=use_fused_kv, pair_cta=pair_cta].
C0_offset
comptime C0_offset = (2 * num_pv_stages)
C1_offset
comptime C1_offset = (Self.C0_offset + 2)
K_barriers
comptime K_barriers = ((2 * num_qk_stages) * num_kv_stages)
K_offset
comptime K_offset = (Self.Q1SyncIdx + num_qk_stages)
num_order_barriers
comptime num_order_barriers = 2 if use_order_barriers else 0
number_warpgroup_count
comptime number_warpgroup_count = Self.S0_producer_offset
O_producer_offset
comptime O_producer_offset = (Self.V_offset + Self.V_barriers)
order_offset
comptime order_offset = (Self.C1_offset + 2)
Q1SyncIdx
comptime Q1SyncIdx = (Self.S1_producer_offset + 1)
S0_consumer_offset
comptime S0_consumer_offset = 0
S0_producer_offset
comptime S0_producer_offset = (Self.order_offset + Self.num_order_barriers)
S1_consumer_offset
comptime S1_consumer_offset = num_pv_stages
S1_producer_offset
comptime S1_producer_offset = (Self.S0_producer_offset + 1)
size
comptime size = (Self.O_producer_offset + 2)
SPipelineConsumer
comptime SPipelineConsumer = RolePipeline[1, False, consumer_sub_stages=num_pv_stages]
SPipelineProducer
comptime SPipelineProducer = RolePipeline[1, consumer_sub_stages=num_pv_stages]
V_barriers
comptime V_barriers = 0 if use_fused_kv else (2 * num_kv_stages)
V_offset
comptime V_offset = (Self.K_offset + Self.K_barriers)
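The offsets form a dependency chain that follows the memory layout from left to right. The Python function below transcribes the comptime definitions in dependency order so the concrete indices can be inspected for a given configuration; fa4_offsets is a hypothetical helper, and pair_cta is omitted because no offset depends on it.

```python
def fa4_offsets(num_qk_stages=1, num_pv_stages=1, num_kv_stages=2,
                use_order_barriers=True, use_fused_kv=False):
    """Evaluate the comptime offset chain of FA4MiscMBars in Python."""
    o = {}
    # count=128 region
    o["S0_consumer_offset"] = 0
    o["S1_consumer_offset"] = num_pv_stages
    o["C0_offset"] = 2 * num_pv_stages
    o["C1_offset"] = o["C0_offset"] + 2
    o["order_offset"] = o["C1_offset"] + 2
    o["num_order_barriers"] = 2 if use_order_barriers else 0
    # count=1 region starts here
    o["S0_producer_offset"] = o["order_offset"] + o["num_order_barriers"]
    o["number_warpgroup_count"] = o["S0_producer_offset"]
    o["S1_producer_offset"] = o["S0_producer_offset"] + 1
    o["Q1SyncIdx"] = o["S1_producer_offset"] + 1
    o["K_offset"] = o["Q1SyncIdx"] + num_qk_stages
    o["K_barriers"] = 2 * num_qk_stages * num_kv_stages
    o["V_offset"] = o["K_offset"] + o["K_barriers"]
    o["V_barriers"] = 0 if use_fused_kv else 2 * num_kv_stages
    o["O_producer_offset"] = o["V_offset"] + o["V_barriers"]
    o["size"] = o["O_producer_offset"] + 2
    return o


for name, value in fa4_offsets().items():
    print(f"{name:24} {value}")
```

With the default parameters, the K region starts at index 11, V at 15, O_prod at 19, and size is 21. In fused mode V_barriers is 0, so O_prod begins where V would have started.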
Methods
__init__
__init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
init
init(self, *, lane_idx: Int32)
producer_s0
producer_s0(self) -> Self.SPipelineProducer
Get the S producer for warp group 0.
Returns:
Self.SPipelineProducer
producer_s1
producer_s1(self) -> Self.SPipelineProducer
Get the S producer for warp group 1.
Returns:
Self.SPipelineProducer
consumer_s
consumer_s(self, wg_idx: UInt32) -> Self.SPipelineConsumer
Get the S consumer for the given warp group.
Returns:
Self.SPipelineConsumer
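From the layout, each warp group has one S producer barrier (S0_prod or S1_prod) and num_pv_stages consecutive S consumer barriers, matching consumer_sub_stages=num_pv_stages in SPipelineConsumer. The Python sketch below shows that index mapping under the assumption that consumer_s(wg_idx) starts at wg_idx * num_pv_stages and the producer for wg_idx sits at S0_producer_offset + wg_idx; s_barrier_indices is a hypothetical helper, not the Mojo code.

```python
def s_barrier_indices(wg_idx, num_pv_stages=1, use_order_barriers=True):
    """Hypothetical (producer index, consumer indices) for one warp group's S pipeline."""
    # The layout has exactly two S slots (S0 and S1).
    assert wg_idx in (0, 1)
    # Consumers: S0_cons at 0, S1_cons at num_pv_stages, each num_pv_stages wide.
    first_consumer = wg_idx * num_pv_stages
    consumers = list(range(first_consumer, first_consumer + num_pv_stages))
    # Producers follow the count=128 region (S consumers, C0/C1, optional Order).
    s0_producer = 2 * num_pv_stages + 4 + (2 if use_order_barriers else 0)
    return s0_producer + wg_idx, consumers


print(s_barrier_indices(1))                   # (9, [1])
print(s_barrier_indices(1, num_pv_stages=2))  # (11, [2, 3])
```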
consumer_c0
consumer_c0(self) -> RolePipeline[1, False]
Returns:
RolePipeline[1, False]
consumer_c1
consumer_c1(self) -> RolePipeline[1, False]
Returns:
RolePipeline[1, False]
producer_c
producer_c(self, wg_idx: UInt32) -> RolePipeline[1]
Returns:
RolePipeline[1]
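C0 and C1 each occupy two consecutive slots (C1_offset = C0_offset + 2). A plausible reading, not confirmed by this reference, is that producer_c(wg_idx) selects C0 for warp group 0 and C1 for warp group 1, while consumer_c0 and consumer_c1 use the same pairs. The Python sketch below encodes that assumption; c_pair is a hypothetical helper, and which of the two slots plays which role inside RolePipeline is not documented here.

```python
def c_pair(wg_idx, num_pv_stages=1):
    """Hypothetical: the two barrier indices of C<wg_idx> (C0 or C1)."""
    c0_offset = 2 * num_pv_stages   # C0_offset
    base = c0_offset + 2 * wg_idx   # C1_offset = C0_offset + 2
    return base, base + 1


print(c_pair(0), c_pair(1))  # (2, 3) (4, 5)
```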
pipeline_order_wait
pipeline_order_wait(self, wg_idx: UInt32) -> MBarType
Returns:
MBarType
pipeline_order_arrive
pipeline_order_arrive(self, wg_idx: UInt32) -> MBarType
Returns:
MBarType
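The two order barriers exist to keep the softmax warp groups from overlapping. The Python threading sketch below models the ping-pong handoff that two such barriers allow: each "warp group" waits on its own barrier, runs its critical section, then signals the other's. Semaphores stand in for mbarriers here, and the pairing of pipeline_order_wait(wg_idx) with the warp group's own barrier and pipeline_order_arrive(wg_idx) with the other one is an assumption; this reference does not say which index each call returns.

```python
import threading


def run_ping_pong(iterations=3):
    # One stand-in "order barrier" per warp group; warp group 0 may go first.
    order = [threading.Semaphore(1), threading.Semaphore(0)]
    log = []

    def softmax_wg(wg_idx):
        other = 1 - wg_idx
        for it in range(iterations):
            order[wg_idx].acquire()   # assumed role of pipeline_order_wait(wg_idx)
            log.append((wg_idx, it))  # serialized "softmax" section
            order[other].release()    # assumed role of pipeline_order_arrive(wg_idx)

    threads = [threading.Thread(target=softmax_wg, args=(wg,)) for wg in (0, 1)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


print(run_ping_pong())
```

Because each side can only proceed after the other has signaled, the sections strictly alternate: (0, 0), (1, 0), (0, 1), (1, 1), and so on.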
q1_wait_mbar
q1_wait_mbar(self) -> MBarType
Returns:
MBarType
get_k_mbars
get_k_mbars(self) -> MBarType
Returns the base pointer for the K pipeline barriers.
Returns:
MBarType
get_v_mbars
get_v_mbars(self) -> MBarType
Returns the base pointer for the V pipeline barriers. In fused mode (use_fused_kv=True), returns the same pointer as get_k_mbars, because K and V share one pipeline.
Returns:
MBarType
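In terms of indices, get_k_mbars points at K_offset, and get_v_mbars points at V_offset unless K and V are fused. The Python sketch below models the two getters as indices into the mbarrier array rather than pointers; kv_bases is a hypothetical helper that recomputes the offsets from the comptime definitions.

```python
def kv_bases(num_qk_stages=1, num_pv_stages=1, num_kv_stages=2,
             use_order_barriers=True, use_fused_kv=False):
    """Hypothetical index model of get_k_mbars / get_v_mbars."""
    s0_producer = 2 * num_pv_stages + 4 + (2 if use_order_barriers else 0)
    q1_sync_idx = s0_producer + 2                  # after S0_prod and S1_prod
    k_offset = q1_sync_idx + num_qk_stages         # K_offset
    v_offset = k_offset + 2 * num_qk_stages * num_kv_stages  # V_offset
    # Fused mode: V reuses the K pipeline, so both getters share one base.
    v_base = k_offset if use_fused_kv else v_offset
    return k_offset, v_base


print(kv_bases())                   # (11, 15)
print(kv_bases(use_fused_kv=True))  # (11, 11)
```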
combined_p_o_consumer
combined_p_o_consumer(self, wg_idx: UInt32) -> MBarType
Combined P+O consumer barrier for the given warp group.
Both softmax (P ready) and correction (O rescaled) arrive on this barrier. Returns S_consumer[0] of warp group wg_idx, where wg_idx is 0 or 1.
Returns:
MBarType
consumer_o
consumer_o(self) -> RolePipeline[2, False, consumer_sub_stages=num_pv_stages]
Get the O consumer pipeline.
Wait side: the O_producer barriers (stride 1, indexed by stage). Release side: the combined S+O barriers (S_consumer[0] of each warp group, stride num_pv_stages).
Returns:
RolePipeline[2, False, consumer_sub_stages=num_pv_stages]
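The two strides described for consumer_o can be made concrete. The Python sketch below assumes that the pipeline's stage index equals the warp group index (consumer_o has 2 stages and there are two O producers), so stage wg waits on O_producer_offset + wg and releases S_consumer[0] of warp group wg, which is also what combined_p_o_consumer(wg) returns. o_consumer_indices is a hypothetical helper, not part of the API.

```python
def o_consumer_indices(wg_idx, num_qk_stages=1, num_pv_stages=1,
                       num_kv_stages=2, use_order_barriers=True,
                       use_fused_kv=False):
    """Hypothetical (wait index, release index) for stage wg_idx of consumer_o."""
    s0_producer = 2 * num_pv_stages + 4 + (2 if use_order_barriers else 0)
    k_offset = s0_producer + 2 + num_qk_stages   # after S0/S1_prod and Q1Sync
    v_offset = k_offset + 2 * num_qk_stages * num_kv_stages
    o_producer_offset = v_offset + (0 if use_fused_kv else 2 * num_kv_stages)
    wait_idx = o_producer_offset + wg_idx        # O_producer barriers, stride 1
    release_idx = wg_idx * num_pv_stages         # S_consumer[0] per wg, stride num_pv_stages
    return wait_idx, release_idx


print(o_consumer_indices(0), o_consumer_indices(1))  # (19, 0) (20, 1)
```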
producer_o0
producer_o0(self) -> RolePipeline[1]
Get the O producer for warp group 0.
Returns:
RolePipeline[1]
producer_o1
producer_o1(self) -> RolePipeline[1]
Get the O producer for warp group 1.
Returns:
RolePipeline[1]
num_mbars