Mojo struct
FA4MiscMBars
@register_passable(trivial)
struct FA4MiscMBars[*, num_qk_stages: Int = 1, num_pv_stages: Int = 1, num_kv_stages: Int = 2, separate_kv: Bool = True, use_order_barriers: Bool = True]
Manages all mbarrier resources for FA4.
This struct consolidates all mbarrier management including:
- S barriers (score MMA synchronization)
- C barriers (correction synchronization)
- Order barriers (softmax ordering)
- Q1Sync barriers (Q tile synchronization)
- K/V pipeline barriers
- O pipeline barriers
Memory layout (count=128 barriers first, then count=1 barriers):
[S0_cons] [S1_cons] [C0] [C1] [Order*] [O_cons] | [S0_prod] [S1_prod] [Q1Sync] [K] [V*] [O_prod]
*Order barriers are only present when use_order_barriers=True.
*V barriers are only present when separate_kv=True.
Parameters
- num_qk_stages (Int): Number of stages for the Q@K' MMA (K loading can be staged).
- num_pv_stages (Int): Number of stages for the P@V MMA (P writing can be staged).
- num_kv_stages (Int): Number of KV buffer stages for double/triple buffering.
- separate_kv (Bool): True for MHA (separate K and V barriers), False for MLA (unified KV barriers).
- use_order_barriers (Bool): When True, order barriers are allocated to prevent softmax warp group overlap. When False, order barriers are omitted.
Fields
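- mbar_base (MBarType): Base pointer to the mbarrier storage. This is the struct's only field; `__init__` takes a `mbar_base` pointer to `SharedMemBarrier` in the shared address space.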
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
C0_offset
comptime C0_offset = (2 * num_pv_stages)
C1_offset
comptime C1_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].C0_offset + 2)
K_barriers
comptime K_barriers = ((2 * num_qk_stages) * num_kv_stages)
K_offset
comptime K_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].Q1SyncIdx + num_qk_stages)
num_order_barriers
comptime num_order_barriers = 2 if use_order_barriers else 0
number_warpgroup_count
comptime number_warpgroup_count = FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].S0_producer_offset
O_consumer_offset
comptime O_consumer_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].order_offset + FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].num_order_barriers)
O_producer_offset
comptime O_producer_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].V_offset + FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].V_barriers)
order_offset
comptime order_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].C1_offset + 2)
Q1SyncIdx
comptime Q1SyncIdx = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].S1_producer_offset + 1)
S0_consumer_offset
comptime S0_consumer_offset = 0
S0_producer_offset
comptime S0_producer_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].O_consumer_offset + 2)
S1_consumer_offset
comptime S1_consumer_offset = num_pv_stages
S1_producer_offset
comptime S1_producer_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].S0_producer_offset + 1)
size
comptime size = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].O_producer_offset + 2)
SPipelineConsumer
comptime SPipelineConsumer = RolePipeline[1, False, consumer_sub_stages=num_pv_stages]
SPipelineProducer
comptime SPipelineProducer = RolePipeline[1, consumer_sub_stages=num_pv_stages]
V_barriers
comptime V_barriers = (2 * num_kv_stages) if separate_kv else 0
V_offset
comptime V_offset = (FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].K_offset + FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].K_barriers)
Methods
__init__
__init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
init
init(self, *, lane_idx: Int32)
producer_s0
producer_s0(self) -> FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].SPipelineProducer
Get S producer for warp group 0.
Returns:
FA4MiscMBars[...].SPipelineProducer (alias of RolePipeline[1, consumer_sub_stages=num_pv_stages])
producer_s1
producer_s1(self) -> FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].SPipelineProducer
Get S producer for warp group 1.
Returns:
FA4MiscMBars[...].SPipelineProducer (alias of RolePipeline[1, consumer_sub_stages=num_pv_stages])
consumer_s
consumer_s(self, wg_idx: UInt32) -> FA4MiscMBars[num_qk_stages=num_qk_stages, num_pv_stages=num_pv_stages, num_kv_stages=num_kv_stages, separate_kv=separate_kv, use_order_barriers=use_order_barriers].SPipelineConsumer
Get S consumer for given warp group.
Returns:
FA4MiscMBars[...].SPipelineConsumer (alias of RolePipeline[1, False, consumer_sub_stages=num_pv_stages])
consumer_c0
consumer_c0(self) -> RolePipeline[1, False]
Returns:
RolePipeline
consumer_c1
consumer_c1(self) -> RolePipeline[1, False]
Returns:
RolePipeline
producer_c
producer_c(self, wg_idx: UInt32) -> RolePipeline[1]
Returns:
RolePipeline
pipeline_order_wait
pipeline_order_wait(self, wg_idx: UInt32) -> MBarType
Returns:
MBarType
pipeline_order_arrive
pipeline_order_arrive(self, wg_idx: UInt32) -> MBarType
Returns:
MBarType
q1_wait_mbar
q1_wait_mbar(self) -> MBarType
Returns:
MBarType
get_k_mbars
get_k_mbars(self) -> MBarType
Returns base pointer for K pipeline barriers.
Returns:
MBarType
get_v_mbars
get_v_mbars(self) -> MBarType
Returns base pointer for V pipeline barriers (MHA only).
Returns:
MBarType
get_kv_mbars
get_kv_mbars(self) -> MBarType
Returns base pointer for unified KV pipeline barriers (MLA).
Returns:
MBarType
producer_o0
producer_o0(self) -> RolePipeline[1]
Get O producer for warp group 0.
Returns:
RolePipeline
producer_o1
producer_o1(self) -> RolePipeline[1]
Get O producer for warp group 1.
Returns:
RolePipeline
consumer_o
consumer_o(self) -> RolePipeline[2, False]
Get O consumer pipeline.
Returns:
RolePipeline
num_mbars
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!