
Mojo struct

FA4MiscMBars

@register_passable(trivial) struct FA4MiscMBars[*, num_qk_stages: Int = 1, num_pv_stages: Int = 1, num_kv_stages: Int = 2, separate_kv: Bool = True, use_order_barriers: Bool = True]

Manages all mbarrier resources for FA4.

This struct consolidates management of the following mbarrier groups:

  • S barriers (score MMA synchronization)
  • C barriers (correction synchronization)
  • Order barriers (softmax ordering)
  • Q1Sync barriers (Q tile synchronization)
  • K/V pipeline barriers
  • O pipeline barriers

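Memory layout (count=128 barriers first, then count=1 barriers):

[S0_cons] [S1_cons] [C0] [C1] [Order*] [O_cons] | [S0_prod] [S1_prod] [Q1Sync] [K] [V**] [O_prod]

* Order barriers are present only when use_order_barriers=True.
** V barriers are present only when separate_kv=True.

Each bracketed group can span several barriers: S0_cons and S1_cons span num_pv_stages each; C0, C1, Order, O_cons, and O_prod span two each; S0_prod and S1_prod span one each; Q1Sync spans num_qk_stages; K spans 2 * num_qk_stages * num_kv_stages; and V spans 2 * num_kv_stages.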
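The layout can be checked by replaying the offset arithmetic listed under comptime members. The following is a Python model of that arithmetic, not the Mojo implementation; `MBarLayout` and `offsets` are hypothetical names introduced only for this sketch. With the default parameters it places the count=128 group at indices 0 through 9 and the count=1 group at indices 10 through 22.

```python
# Python model of the FA4MiscMBars comptime offset arithmetic.
# Hypothetical helper: it mirrors the documented comptime members only.
from dataclasses import dataclass


@dataclass(frozen=True)
class MBarLayout:
    num_qk_stages: int = 1
    num_pv_stages: int = 1
    num_kv_stages: int = 2
    separate_kv: bool = True
    use_order_barriers: bool = True

    def offsets(self) -> dict:
        o = {}
        # Group 1: count=128 barriers.
        o["S0_consumer_offset"] = 0
        o["S1_consumer_offset"] = self.num_pv_stages
        o["C0_offset"] = 2 * self.num_pv_stages
        o["C1_offset"] = o["C0_offset"] + 2
        o["order_offset"] = o["C1_offset"] + 2
        num_order = 2 if self.use_order_barriers else 0
        o["O_consumer_offset"] = o["order_offset"] + num_order
        # Group 2: count=1 barriers start right after the O consumer pair.
        o["S0_producer_offset"] = o["O_consumer_offset"] + 2
        o["number_warpgroup_count"] = o["S0_producer_offset"]
        o["S1_producer_offset"] = o["S0_producer_offset"] + 1
        o["Q1SyncIdx"] = o["S1_producer_offset"] + 1
        o["K_offset"] = o["Q1SyncIdx"] + self.num_qk_stages
        k_barriers = 2 * self.num_qk_stages * self.num_kv_stages
        o["V_offset"] = o["K_offset"] + k_barriers
        v_barriers = 2 * self.num_kv_stages if self.separate_kv else 0
        o["O_producer_offset"] = o["V_offset"] + v_barriers
        o["size"] = o["O_producer_offset"] + 2
        return o


if __name__ == "__main__":
    for name, value in MBarLayout().offsets().items():
        print(f"{name:>22} = {value}")
```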

Parameters

  • num_qk_stages (Int): Number of stages for Q@K' MMA (K loading can be staged).
  • num_pv_stages (Int): Number of stages for P@V MMA (P writing can be staged).
  • num_kv_stages (Int): Number of K/V buffer stages (for example, 2 for double buffering or 3 for triple buffering).
  • separate_kv (Bool): True for MHA (separate K/V barriers), False for MLA (unified KV).
  • use_order_barriers (Bool): When True, allocate order barriers to prevent softmax warp group overlap. When False, order barriers are omitted.
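To see how these parameters change the barrier budget, the sketch below sums the per-group barrier counts implied by the comptime members. It is a Python model; `barrier_counts` is a hypothetical helper, and the per-group tallies follow the memory layout described above.

```python
def barrier_counts(num_qk_stages=1, num_pv_stages=1, num_kv_stages=2,
                   separate_kv=True, use_order_barriers=True):
    """Return (count=128 barriers, count=1 barriers, total). Hypothetical model."""
    num_order = 2 if use_order_barriers else 0
    # count=128 group: S0/S1 consumers, C0 (2), C1 (2), Order, O consumer (2)
    wg_group = 2 * num_pv_stages + 2 + 2 + num_order + 2
    k_barriers = 2 * num_qk_stages * num_kv_stages
    v_barriers = 2 * num_kv_stages if separate_kv else 0
    # count=1 group: S0/S1 producers, Q1Sync, K, V, O producer (2)
    single_group = 1 + 1 + num_qk_stages + k_barriers + v_barriers + 2
    return wg_group, single_group, wg_group + single_group


print(barrier_counts())                          # MHA defaults
print(barrier_counts(separate_kv=False))         # MLA, unified KV
print(barrier_counts(use_order_barriers=False))  # no softmax ordering
```

With the defaults this gives 10 + 13 = 23 barriers. Switching to MLA (separate_kv=False) drops the four V barriers (19 total), and disabling order barriers removes two barriers from the count=128 group (21 total).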

Fields

  • mbar_base (MBarType): Base pointer to the shared-memory array of mbarriers managed by this struct (passed to __init__).

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

C0_offset

comptime C0_offset = (2 * num_pv_stages)

C1_offset

comptime C1_offset = (Self.C0_offset + 2)

K_barriers

comptime K_barriers = ((2 * num_qk_stages) * num_kv_stages)

K_offset

comptime K_offset = (Self.Q1SyncIdx + num_qk_stages)

num_order_barriers

comptime num_order_barriers = 2 if use_order_barriers else 0

number_warpgroup_count

comptime number_warpgroup_count = Self.S0_producer_offset

Number of count=128 barriers; equal to the index of the first count=1 barrier (S0_producer_offset).

O_consumer_offset

comptime O_consumer_offset = (Self.order_offset + Self.num_order_barriers)

O_producer_offset

comptime O_producer_offset = (Self.V_offset + Self.V_barriers)

order_offset

comptime order_offset = (Self.C1_offset + 2)

Q1SyncIdx

comptime Q1SyncIdx = (Self.S1_producer_offset + 1)

S0_consumer_offset

comptime S0_consumer_offset = 0

S0_producer_offset

comptime S0_producer_offset = (Self.O_consumer_offset + 2)

S1_consumer_offset

comptime S1_consumer_offset = num_pv_stages

S1_producer_offset

comptime S1_producer_offset = (Self.S0_producer_offset + 1)

size

comptime size = (Self.O_producer_offset + 2)
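Unrolling the chain of comptime offsets gives a closed form for size. Here p = num_pv_stages, q = num_qk_stages, k = num_kv_stages, N_ord = num_order_barriers, and V = V_barriers; this is derived from the members listed on this page, not a formula taken from the source.

```latex
\begin{aligned}
\texttt{C0\_offset} &= 2p \\
\texttt{O\_consumer\_offset} &= 2p + 2 + 2 + N_{\text{ord}} = 2p + 4 + N_{\text{ord}} \\
\texttt{S0\_producer\_offset} &= 2p + 6 + N_{\text{ord}} \\
\texttt{Q1SyncIdx} &= 2p + 8 + N_{\text{ord}} \\
\texttt{K\_offset} &= 2p + 8 + N_{\text{ord}} + q \\
\texttt{V\_offset} &= 2p + 8 + N_{\text{ord}} + q + 2qk \\
\texttt{O\_producer\_offset} &= 2p + 8 + N_{\text{ord}} + q + 2qk + V \\
\texttt{size} &= 2p + 10 + N_{\text{ord}} + q(1 + 2k) + V
\end{aligned}
```

With the defaults (p = 1, q = 1, k = 2, N_ord = 2, V = 4) this evaluates to 2 + 10 + 2 + 5 + 4 = 23.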

SPipelineConsumer

comptime SPipelineConsumer = RolePipeline[1, False, consumer_sub_stages=num_pv_stages]

SPipelineProducer

comptime SPipelineProducer = RolePipeline[1, consumer_sub_stages=num_pv_stages]

V_barriers

comptime V_barriers = (2 * num_kv_stages) if separate_kv else 0

V_offset

comptime V_offset = (Self.K_offset + Self.K_barriers)

Methods

__init__

__init__(mbar_base: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

init

init(self, *, lane_idx: Int32)

producer_s0

producer_s0(self) -> Self.SPipelineProducer

Get the S producer for warp group 0.

Returns:

SPipelineProducer (RolePipeline[1, consumer_sub_stages=num_pv_stages])

producer_s1

producer_s1(self) -> Self.SPipelineProducer

Get the S producer for warp group 1.

Returns:

SPipelineProducer (RolePipeline[1, consumer_sub_stages=num_pv_stages])

consumer_s

consumer_s(self, wg_idx: UInt32) -> Self.SPipelineConsumer

Get the S consumer for the given warp group.

Returns:

SPipelineConsumer (RolePipeline[1, False, consumer_sub_stages=num_pv_stages])

consumer_c0

consumer_c0(self) -> RolePipeline[1, False]

Returns:

RolePipeline

consumer_c1

consumer_c1(self) -> RolePipeline[1, False]

Returns:

RolePipeline

producer_c

producer_c(self, wg_idx: UInt32) -> RolePipeline[1]

Returns:

RolePipeline

pipeline_order_wait

pipeline_order_wait(self, wg_idx: UInt32) -> MBarType

Returns:

MBarType

pipeline_order_arrive

pipeline_order_arrive(self, wg_idx: UInt32) -> MBarType

Returns:

MBarType

q1_wait_mbar

q1_wait_mbar(self) -> MBarType

Returns:

MBarType

get_k_mbars

get_k_mbars(self) -> MBarType

Returns base pointer for K pipeline barriers.

Returns:

MBarType

get_v_mbars

get_v_mbars(self) -> MBarType

Returns base pointer for V pipeline barriers (MHA only).

Returns:

MBarType

get_kv_mbars

get_kv_mbars(self) -> MBarType

Returns base pointer for unified KV pipeline barriers (MLA).

Returns:

MBarType
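Assuming get_k_mbars and get_v_mbars return mbar_base advanced to K_offset and V_offset respectively (an assumption based on the member names, not confirmed by this page), the K and V barriers occupy the half-open index ranges computed below. `kv_ranges` is a hypothetical Python helper that replays the offset arithmetic. In the MLA configuration (separate_kv=False) V_barriers is 0, so the V range is empty and O_producer_offset coincides with V_offset.

```python
def kv_ranges(num_qk_stages=1, num_pv_stages=1, num_kv_stages=2,
              separate_kv=True, use_order_barriers=True):
    """Return half-open (start, end) index ranges for the K and V barriers.

    Hypothetical helper; assumes the K/V pointers sit at K_offset/V_offset.
    """
    num_order = 2 if use_order_barriers else 0
    # Q1SyncIdx unrolled from the comptime members: 2p + 8 + num_order.
    q1_sync_idx = 2 * num_pv_stages + 8 + num_order
    k_offset = q1_sync_idx + num_qk_stages
    k_barriers = 2 * num_qk_stages * num_kv_stages
    v_offset = k_offset + k_barriers
    v_barriers = 2 * num_kv_stages if separate_kv else 0
    return (k_offset, v_offset), (v_offset, v_offset + v_barriers)


print(kv_ranges())                   # MHA: separate K and V ranges
print(kv_ranges(separate_kv=False))  # MLA: empty V range
```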

producer_o0

producer_o0(self) -> RolePipeline[1]

Get O producer for warp group 0.

Returns:

RolePipeline

producer_o1

producer_o1(self) -> RolePipeline[1]

Get O producer for warp group 1.

Returns:

RolePipeline

consumer_o

consumer_o(self) -> RolePipeline[2, False]

Get O consumer pipeline.

Returns:

RolePipeline

num_mbars

static num_mbars() -> UInt32

Returns:

UInt32
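A rough shared-memory budget for these barriers is the barrier count times the size of one mbarrier object. The sketch below assumes that num_mbars() returns the size member and that each SharedMemBarrier occupies 8 bytes (a PTX mbarrier object is a 64-bit, 8-byte-aligned value in shared memory). Both helpers are hypothetical Python.

```python
MBARRIER_BYTES = 8  # assumption: a PTX mbarrier object is 64 bits, 8-byte aligned


def assumed_num_mbars(num_qk_stages=1, num_pv_stages=1, num_kv_stages=2,
                      separate_kv=True, use_order_barriers=True) -> int:
    # Closed form of the documented `size` member (assumed to equal num_mbars()).
    num_order = 2 if use_order_barriers else 0
    v_barriers = 2 * num_kv_stages if separate_kv else 0
    return (2 * num_pv_stages + 10 + num_order
            + num_qk_stages * (1 + 2 * num_kv_stages) + v_barriers)


def mbar_smem_bytes(**params) -> int:
    # Bytes of shared memory taken by the barrier array under the assumptions above.
    return assumed_num_mbars(**params) * MBARRIER_BYTES


print(mbar_smem_bytes())                   # default MHA configuration
print(mbar_smem_bytes(separate_kv=False))  # MLA configuration
```

Under these assumptions the default MHA configuration needs 23 × 8 = 184 bytes and the MLA configuration 19 × 8 = 152 bytes.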
