Mojo module
barriers
Barrier infrastructure for depth=256/512 pair-CTA SM100 attention kernels.
Manages all mbarrier resources for the warp-specialized kernel. Each CTA has 12 warps (384 threads, 3 warp groups of 128), split into four working roles plus spares: softmax (warps 0-3, 128 threads), correction (warps 4-7, 128 threads), MMA (warp 8), load (warp 9), and spare (warps 10-11).
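The sketch below makes the warp partition concrete. It is a minimal host-side Python model, not Mojo and not part of the barriers module. It maps a CTA-local thread index to its role, assuming warp_id = thread_idx // 32. The names warp_role and the role strings are hypothetical.

```python
# Hypothetical host-side model of the warp-role split (not the Mojo API).
# Assumes warp_id = thread_idx // 32 within a single 384-thread CTA.

WARP_SIZE = 32
NUM_WARPS = 12
THREADS_PER_CTA = NUM_WARPS * WARP_SIZE  # 384 threads = 3 warp groups of 128


def warp_role(thread_idx: int) -> str:
    """Map a CTA-local thread index to its warp role."""
    if not 0 <= thread_idx < THREADS_PER_CTA:
        raise ValueError(f"thread_idx {thread_idx} outside 0..{THREADS_PER_CTA - 1}")
    warp_id = thread_idx // WARP_SIZE
    if warp_id <= 3:
        return "softmax"      # warps 0-3: 128 threads
    if warp_id <= 7:
        return "correction"   # warps 4-7: 128 threads
    if warp_id == 8:
        return "mma"          # single MMA-issuing warp
    if warp_id == 9:
        return "load"         # single TMA-load warp
    return "spare"            # warps 10-11


if __name__ == "__main__":
    counts = {}
    for t in range(THREADS_PER_CTA):
        role = warp_role(t)
        counts[role] = counts.get(role, 0) + 1
    print(counts)  # {'softmax': 128, 'correction': 128, 'mma': 32, 'load': 32, 'spare': 64}
```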
When split_o=True (depth=512), O is split into O_lo/O_hi with separate commit barriers (O_mma_lo, O_mma_hi) and ready barriers (PO_lo, PO_hi). When split_o=False (depth=256), PO_hi and O_mma_hi are eliminated (8 fixed barriers instead of 10).
Barrier layout (split_o=True, 10 fixed):
- [0] PO_lo (count-512: softmax 128×2 CTAs + correction 128×2 CTAs)
- [1] PO_hi (count-256: correction 128×2 CTAs)
- [2] S_even consumer (count-256: softmax 128×2 CTAs)
- [3] S_odd consumer (count-256: softmax 128×2 CTAs)
- [4] C producer (count-128: softmax 4 warps, CTA-local)
- [5] C consumer (count-128: correction 4 warps, CTA-local)
- [6] S_even producer (count-1: MMA mma_arrive)
- [7] S_odd producer (count-1: MMA mma_arrive)
- [8] O_mma_lo (count-1: MMA mma_arrive after V_lo stages)
- [9] O_mma_hi (count-1: MMA mma_arrive after V_hi stages)
- [10..] KV pipeline (count-1: 2 per stage, producer+consumer pairs)
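Each arrival count in the split_o=True layout follows from the per-role thread counts. The Python sketch below derives them. It assumes two things: every thread of an arriving role arrives once per phase, and the ×2 factor means both CTAs of the pair arrive on the same barrier. FIXED_SPLIT_O and the other names are illustrative, not fields of Depth512MBars.

```python
# Sketch: derive each fixed barrier's arrival count for split_o=True.
# Assumes one arrival per thread of the arriving role per phase, and that
# the ×2 factor counts arrivals from both CTAs of the pair.
# Names are illustrative, not module fields.

SOFTMAX_THREADS = 128     # warps 0-3
CORRECTION_THREADS = 128  # warps 4-7
CTAS_PER_PAIR = 2
MMA_ARRIVE = 1            # a single mma_arrive completes the phase

# (name, expected arrival count) in slot order [0]..[9]
FIXED_SPLIT_O = [
    ("PO_lo", (SOFTMAX_THREADS + CORRECTION_THREADS) * CTAS_PER_PAIR),
    ("PO_hi", CORRECTION_THREADS * CTAS_PER_PAIR),
    ("S_even_consumer", SOFTMAX_THREADS * CTAS_PER_PAIR),
    ("S_odd_consumer", SOFTMAX_THREADS * CTAS_PER_PAIR),
    ("C_producer", SOFTMAX_THREADS),        # CTA-local: 4 softmax warps
    ("C_consumer", CORRECTION_THREADS),     # CTA-local: 4 correction warps
    ("S_even_producer", MMA_ARRIVE),
    ("S_odd_producer", MMA_ARRIVE),
    ("O_mma_lo", MMA_ARRIVE),
    ("O_mma_hi", MMA_ARRIVE),
]

if __name__ == "__main__":
    for idx, (name, count) in enumerate(FIXED_SPLIT_O):
        print(f"[{idx}] {name}: count-{count}")
```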
Barrier layout (split_o=False, 8 fixed):
- [0] PO_lo (count-512)
- [1] S_even consumer (count-256)
- [2] S_odd consumer (count-256)
- [3] C producer (count-128)
- [4] C consumer (count-128)
- [5] S_even producer (count-1)
- [6] S_odd producer (count-1)
- [7] O_mma_lo (count-1)
- [8..] KV pipeline (count-1: 2 per stage)
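The two layouts differ only in the PO_hi and O_mma_hi slots, so the KV pipeline starts at slot 10 or slot 8. The hypothetical Python resolver below reproduces both fixed tables. Its KV ordering is an assumption: it places the producer at base + 2*stage and the consumer at base + 2*stage + 1. The layout only says "2 per stage", so treat that interleaving as unverified.

```python
# Hypothetical slot resolver for both layouts (not the Depth512MBars API).
# Fixed slots follow the split_o=True / split_o=False tables; the KV
# ordering is an ASSUMPTION: producer at base + 2*stage, consumer at
# base + 2*stage + 1.


def fixed_names(split_o: bool) -> list:
    """Fixed barrier names in slot order."""
    names = ["PO_lo"]
    if split_o:
        names.append("PO_hi")  # only exists when O is split
    names += [
        "S_even_consumer", "S_odd_consumer",
        "C_producer", "C_consumer",
        "S_even_producer", "S_odd_producer",
        "O_mma_lo",
    ]
    if split_o:
        names.append("O_mma_hi")  # only exists when O is split
    return names


def barrier_index(name: str, split_o: bool) -> int:
    """Slot of a fixed barrier; raises ValueError if absent in this layout."""
    return fixed_names(split_o).index(name)


def kv_index(stage: int, producer: bool, split_o: bool) -> int:
    """Slot of a KV-pipeline barrier (interleaved layout assumed)."""
    base = len(fixed_names(split_o))  # 10 if split_o else 8
    return base + 2 * stage + (0 if producer else 1)


def total_barriers(num_kv_stages: int, split_o: bool) -> int:
    """Fixed barriers plus one producer/consumer pair per KV stage."""
    return len(fixed_names(split_o)) + 2 * num_kv_stages


if __name__ == "__main__":
    print(barrier_index("O_mma_lo", True), barrier_index("O_mma_lo", False))  # 8 7
    print(kv_index(0, True, True), kv_index(0, True, False))                  # 10 8
```

With split_o=False, barrier_index("PO_hi", False) raises ValueError, which mirrors the slot's removal from the depth=256 layout.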
Synchronization flow per KV iteration (even S buffer):
- Load: TMA K/V sub-tile → arrive KV[stage] producer
- MMA:
  - wait KV + S_even consumer → Q@K'→S → mma_arrive S_even producer
  - wait KV + PO_lo → P@V_lo→O_lo → mma_arrive O_mma_lo
  - wait PO_hi → P@V_hi→O_hi (reuse KV slots) → mma_arrive O_mma_hi
- Softmax:
  - wait S_even producer → load S TMEM→regs → arrive S_even consumer
  - exp(S) → write P to SMEM → arrive PO_lo
  - write correction → arrive C producer
- Correction:
  - wait C producer → read correction
  - wait O_mma_lo → rescale O_lo → arrive PO_lo
  - wait O_mma_hi → rescale O_hi → arrive PO_hi → arrive C consumer
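The PO_lo handshake is the subtle step in this flow. MMA may not issue P@V_lo into O_lo until two things have happened: softmax has written P to SMEM, and correction has rescaled O_lo. The Python toy below illustrates this. It is neither the Mojo nor the PTX API. It counts arrivals and flips a phase parity once the expected count is reached, loosely following mbarrier try_wait.parity semantics. Real mbarriers also track transaction bytes and are waited on concurrently, and ToyMBarrier is a hypothetical name.

```python
# Toy, single-threaded model of mbarrier arrival counting (NOT the Mojo or
# PTX API). A phase completes when `count` arrivals land; the parity then
# flips and the barrier re-arms. Real mbarriers also track tx bytes.


class ToyMBarrier:
    def __init__(self, name: str, count: int) -> None:
        self.name = name
        self.count = count      # arrivals needed to complete one phase
        self.pending = count    # arrivals still missing in the open phase
        self.phase = 0          # parity of the open phase

    def arrive(self, n: int = 1) -> None:
        if not 1 <= n <= self.pending:
            raise ValueError(f"{self.name}: {n} arrivals, {self.pending} pending")
        self.pending -= n
        if self.pending == 0:   # phase complete: flip parity and re-arm
            self.phase ^= 1
            self.pending = self.count

    def try_wait(self, parity: int) -> bool:
        """True once the phase with this parity has completed."""
        return self.phase != parity


CTAS = 2
po_lo = ToyMBarrier("PO_lo", 512)
o_mma_lo = ToyMBarrier("O_mma_lo", 1)

# The MMA warp waits on PO_lo (parity 0) before issuing P@V_lo.
po_lo.arrive(128 * CTAS)        # softmax threads: P written to SMEM
assert not po_lo.try_wait(0)    # O_lo not rescaled yet, so MMA still blocks
po_lo.arrive(128 * CTAS)        # correction threads: O_lo rescaled
assert po_lo.try_wait(0)        # MMA may now issue P@V_lo -> O_lo
o_mma_lo.arrive()               # mma_arrive O_mma_lo once the MMAs commit
assert o_mma_lo.try_wait(0)     # correction's wait on O_mma_lo now succeeds
assert not po_lo.try_wait(1)    # the next PO_lo phase is still open
```

Arrival order does not matter to the barrier. Softmax and correction can arrive in either order, and only the total of 512 completes the phase.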
Structs
- Depth512MBars: Manages all mbarrier resources for depth=256/512 pair-CTA attention.