Mojo module
barriers
Barrier infrastructure for depth=512 pair-CTA SM100 attention kernels.
Manages all mbarrier resources for the warp-specialized kernel. Each CTA has 12 warps (384 threads, three warp groups of 128 threads) divided into four active roles plus two spare warps: softmax (warps 0-3, 128 threads), correction (warps 4-7, 128 threads), MMA (warp 8), load (warp 9), and spare (warps 10-11).
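The warp-to-role assignment above can be modeled in a few lines. This is a hypothetical Python sketch, not the kernel's code: `warp_role`, `warp_group`, and `threads_per_role` are illustrative names. It shows why the softmax- and correction-side barriers use an arrival count of 128: each of those roles spans exactly four 32-thread warps, i.e. one full warp group.

```python
# Hypothetical model of the warp roles described above; not the kernel's code.
WARP_SIZE = 32
NUM_WARPS = 12
WARPS_PER_GROUP = 4  # 4 warps x 32 threads = 128-thread warp group


def warp_role(warp_id: int) -> str:
    """Map a warp index within the CTA to its role."""
    if not 0 <= warp_id < NUM_WARPS:
        raise ValueError(f"warp_id {warp_id} out of range")
    if warp_id < 4:
        return "softmax"
    if warp_id < 8:
        return "correction"
    if warp_id == 8:
        return "mma"
    if warp_id == 9:
        return "load"
    return "spare"


def warp_group(warp_id: int) -> int:
    """Warp group (0, 1 or 2) that a warp belongs to."""
    return warp_id // WARPS_PER_GROUP


def threads_per_role() -> dict:
    """Total thread count per role across the CTA."""
    counts: dict = {}
    for w in range(NUM_WARPS):
        role = warp_role(w)
        counts[role] = counts.get(role, 0) + WARP_SIZE
    return counts
```

Under this model, softmax occupies warp group 0, correction warp group 1, and the MMA, load, and spare warps share warp group 2.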
O is split into two halves (O_lo, O_hi) produced by separate P@V MMA instructions (MMA_N=ov_depth/2 each). Each half has its own commit barrier (O_mma_lo, O_mma_hi) and ready barrier (PO_lo, PO_hi). This lets the correction warps rescale O_lo while the MMA is still producing O_hi.
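Because the split is along O's N dimension (columns), rescaling each half separately gives the same result as rescaling the whole row, which is what makes the overlap safe. The sketch below assumes the correction value is the usual online-softmax (FlashAttention-style) factor exp(m_old - m_new); the page does not spell out its exact contents, and all function names here are hypothetical.

```python
import math

# Hypothetical sketch; assumes the correction is the standard online-softmax
# rescale factor exp(m_old - m_new). Not the kernel's implementation.
OV_DEPTH = 512
MMA_N = OV_DEPTH // 2  # each P@V MMA produces one half of O's columns


def split_o_row(o_row):
    """Split one row of O into (O_lo, O_hi) column halves."""
    if len(o_row) != OV_DEPTH:
        raise ValueError("row length must equal OV_DEPTH")
    return o_row[:MMA_N], o_row[MMA_N:]


def correction_factor(m_old: float, m_new: float) -> float:
    """Rescale factor applied to O when the running row max grows."""
    return math.exp(m_old - m_new)


def rescale(values, factor):
    return [v * factor for v in values]


def rescale_in_halves(o_row, factor):
    o_lo, o_hi = split_o_row(o_row)
    new_lo = rescale(o_lo, factor)  # may start once O_mma_lo has committed
    new_hi = rescale(o_hi, factor)  # must wait for O_mma_hi
    return new_lo + new_hi
```

Each element is multiplied by the same factor either way, so the per-half path is bit-identical to rescaling the full row; only the scheduling changes.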
Barrier layout (low to high index):
- [0] PO_lo (count=256: softmax 128 + correction 128)
- [1] PO_hi (count=128: correction 128 only)
- [2] S_even consumer (count=128: softmax 4 warps)
- [3] S_odd consumer (count=128: softmax 4 warps)
- [4] C producer (count=128: softmax 4 warps)
- [5] C consumer (count=128: correction 4 warps)
- [6] S_even producer (count=1: MMA mma_arrive)
- [7] S_odd producer (count=1: MMA mma_arrive)
- [8] O_mma_lo (count=1: MMA mma_arrive after V_lo stages)
- [9] O_mma_hi (count=1: MMA mma_arrive after V_hi stages)
- [10..] KV pipeline (count=1: 2 per stage, producer+consumer pairs)
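The layout can be expressed as a small index table. This Python sketch is illustrative only: the number of KV stages is a parameter the page does not fix, and placing each stage's producer barrier before its consumer barrier (10 + 2*stage, then 10 + 2*stage + 1) is an assumption based on the "producer+consumer pairs" wording. The names `barrier_table`, `kv_producer_index`, and `kv_consumer_index` are hypothetical.

```python
# Hypothetical index model of the barrier layout; not the Depth512MBars API.
FIXED_BARRIERS = [
    ("PO_lo", 256),            # softmax 128 + correction 128
    ("PO_hi", 128),            # correction 128 only
    ("S_even_consumer", 128),  # softmax 4 warps
    ("S_odd_consumer", 128),
    ("C_producer", 128),       # softmax 4 warps
    ("C_consumer", 128),       # correction 4 warps
    ("S_even_producer", 1),    # MMA mma_arrive
    ("S_odd_producer", 1),
    ("O_mma_lo", 1),
    ("O_mma_hi", 1),
]
NUM_FIXED = len(FIXED_BARRIERS)  # KV pipeline barriers start here (index 10)


def kv_producer_index(stage: int) -> int:
    # Assumption: producer comes first within each per-stage pair.
    return NUM_FIXED + 2 * stage


def kv_consumer_index(stage: int) -> int:
    return NUM_FIXED + 2 * stage + 1


def total_barriers(num_kv_stages: int) -> int:
    return NUM_FIXED + 2 * num_kv_stages


def barrier_table(num_kv_stages: int):
    """List of (index, name, arrival count), low to high index."""
    table = [(i, name, count) for i, (name, count) in enumerate(FIXED_BARRIERS)]
    for s in range(num_kv_stages):
        table.append((kv_producer_index(s), f"KV[{s}]_producer", 1))
        table.append((kv_consumer_index(s), f"KV[{s}]_consumer", 1))
    return table
```

Keeping the ten fixed barriers at the front means every KV barrier index is a simple affine function of the stage number, whatever the stage count.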
Synchronization flow per KV iteration (even S buffer):
- Load: TMA K/V sub-tile → arrive KV[stage] producer
- MMA:
  - wait KV + S_even consumer → Q@K'→S → mma_arrive S_even producer
  - wait KV + PO_lo → P@V_lo→O_lo → mma_arrive O_mma_lo
  - wait PO_hi → P@V_hi→O_hi (reuse KV slots) → mma_arrive O_mma_hi
- Softmax:
  - wait S_even producer → load S TMEM→regs → arrive S_even consumer
  - exp(S) → write P to SMEM → arrive PO_lo
  - write correction → arrive C producer
- Correction:
  - wait C producer → read correction
  - wait O_mma_lo → rescale O_lo → arrive PO_lo
  - wait O_mma_hi → rescale O_hi → arrive PO_hi → arrive C consumer
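One way to sanity-check the flow against the layout is to tally the arrivals each role makes in one even-buffer iteration and confirm that every touched barrier receives exactly its configured count, so each completes one phase. The Python sketch below is a hypothetical consistency check, not the kernel's code. It assumes every thread in the softmax and correction warp groups arrives individually (128 arrivals per role) and that each MMA `mma_arrive` counts as one arrival. KV pipeline barriers are left out because the flow does not describe when KV slots are released.

```python
from collections import Counter

# Hypothetical consistency check; counts transcribed from the barrier layout.
EXPECTED_COUNT = {
    "PO_lo": 256, "PO_hi": 128,
    "S_even_consumer": 128, "S_odd_consumer": 128,
    "C_producer": 128, "C_consumer": 128,
    "S_even_producer": 1, "S_odd_producer": 1,
    "O_mma_lo": 1, "O_mma_hi": 1,
}
SOFTMAX_THREADS = 128     # warps 0-3
CORRECTION_THREADS = 128  # warps 4-7

# (barrier, number of arrivals) per role for one even-S-buffer iteration.
EVEN_ITERATION_ARRIVALS = {
    "mma": [("S_even_producer", 1), ("O_mma_lo", 1), ("O_mma_hi", 1)],
    "softmax": [("S_even_consumer", SOFTMAX_THREADS),
                ("PO_lo", SOFTMAX_THREADS),
                ("C_producer", SOFTMAX_THREADS)],
    "correction": [("PO_lo", CORRECTION_THREADS),
                   ("PO_hi", CORRECTION_THREADS),
                   ("C_consumer", CORRECTION_THREADS)],
}


def tally(arrivals_by_role) -> Counter:
    """Sum arrivals per barrier across all roles."""
    totals: Counter = Counter()
    for events in arrivals_by_role.values():
        for barrier, n in events:
            totals[barrier] += n
    return totals


def mismatches(totals) -> dict:
    """Barriers whose arrival total differs from their configured count."""
    return {b: (n, EXPECTED_COUNT[b])
            for b, n in totals.items() if n != EXPECTED_COUNT[b]}
```

With these inputs `mismatches` returns an empty dict: PO_lo collects 128 + 128 = 256 arrivals from softmax and correction, while PO_hi collects only the correction group's 128, matching the split described in the layout. The odd-buffer barriers are untouched in this iteration.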
Structs
- Depth512MBars: Manages all mbarrier resources for depth=512 pair-CTA attention.