Mojo module
correction_warp
Correction warp group logic for depth=512 pair-CTA SM100 attention.
Rescales the O accumulator in TMEM when the per-row maximum changes during online softmax. O is split into two halves (O_lo, O_hi) produced by separate P@V MMA instructions (MMA_N=ov_depth/2 each). The correction warp processes O_lo first, releasing PO_lo to unblock the next iteration's P@V_lo, then processes O_hi and releases PO_hi.
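The rescaling arithmetic can be illustrated with a small plain-Python model of one query row. This is only a sketch of the online-softmax math, not the Mojo kernel: the function name `online_softmax_step` and its arguments are hypothetical. Folding the P@V accumulation into the same function is a simplification, since in the kernel that work is done by separate MMA instructions and the correction warp only applies the scale factor.

```python
import math


def online_softmax_step(m_old, l_old, o_lo, o_hi, scores, v_rows):
    """Model one online-softmax update for a single query row.

    m_old, l_old: running row maximum and running sum of exponentials.
    o_lo, o_hi:   the two halves of the unnormalized O accumulator.
    scores:       the new block of attention scores for this row.
    v_rows:       one V row per score, each of length len(o_lo) + len(o_hi).
    (All names are hypothetical; this is not the module's API.)
    """
    half = len(o_lo)
    m_new = max(m_old, max(scores))
    # Correction factor: 1.0 when the row maximum did not change.
    scale = math.exp(m_old - m_new)

    # Correction step: O_lo is rescaled first, then O_hi. In the kernel,
    # finishing O_lo is what lets the next P@V_lo start early.
    o_lo = [x * scale for x in o_lo]
    o_hi = [x * scale for x in o_hi]

    # Accumulate the new block (done by the P@V MMAs in the real kernel).
    p = [math.exp(s - m_new) for s in scores]
    l_new = l_old * scale + sum(p)
    for p_j, v in zip(p, v_rows):
        for c in range(half):
            o_lo[c] += p_j * v[c]
            o_hi[c] += p_j * v[half + c]
    return m_new, l_new, o_lo, o_hi
```

Starting from `m = float("-inf")`, `l = 0.0`, and zeroed halves, feeding the score blocks one at a time and dividing the final `o_lo + o_hi` by `l` should reproduce an ordinary softmax-weighted sum of the V rows.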
Pair-CTA TMEM column layout for O (4 quadrants):

First MMA (O_lo, TMEM base = TMEM_O, ov_depth/4 physical cols):
  Rows 0-63: O logical cols 0 .. ov_depth/4 - 1
  Rows 64-127: O logical cols ov_depth/4 .. ov_depth/2 - 1

Second MMA (O_hi, TMEM base = TMEM_O_hi, ov_depth/4 physical cols):
  Rows 0-63: O logical cols ov_depth/2 .. 3*ov_depth/4 - 1
  Rows 64-127: O logical cols 3*ov_depth/4 .. ov_depth - 1
Both row groups access the SAME physical TMEM column range within each MMA region; the pair-CTA layout distinguishes them by row, not by column address. All 128 threads participate in both phases, each processing ov_depth/4 physical columns per phase.
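The quadrant layout described above can be written as a small index mapping. The following plain-Python sketch is illustrative only: the function `logical_col` and its parameter names are hypothetical and not part of this module. It assumes `half` is 0 for O_lo and 1 for O_hi, and that `phys_col` is an offset from the TMEM base of that half.

```python
def logical_col(half, row, phys_col, ov_depth):
    """Map (MMA half, TMEM row, physical column offset) to an O logical column.

    half:     0 for O_lo (base TMEM_O), 1 for O_hi (base TMEM_O_hi).
    row:      TMEM row, 0..127.
    phys_col: column offset within the half, 0..ov_depth/4 - 1.
    (Hypothetical helper for illustration; not the module's API.)
    """
    quarter = ov_depth // 4
    if half not in (0, 1):
        raise ValueError("half must be 0 (O_lo) or 1 (O_hi)")
    if not 0 <= row < 128:
        raise ValueError("row must be in 0..127")
    if not 0 <= phys_col < quarter:
        raise ValueError("phys_col must be in 0..ov_depth/4 - 1")
    # Rows 0-63 hold the first quarter of the half, rows 64-127 the second.
    row_group = row // 64
    return half * (ov_depth // 2) + row_group * quarter + phys_col
```

For `ov_depth = 512`, each half spans 128 physical columns, and the four (half, row group) combinations together cover logical columns 0 through 511 exactly once.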
Functions