Mojo module
correction_warp
Correction warp group logic for depth=256/512 pair-CTA SM100 attention.
Rescales the O accumulator in TMEM when the per-row maximum changes during online softmax.
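The rescale step can be illustrated with a plain-Python reference model of online softmax for a single query row. This is only a sketch of the arithmetic, not the module's kernel code: the function names, the list-based data layout, and the `block` parameter are assumptions made for illustration, and nothing here models TMEM, warps, or the MMA pipeline.

```python
import math

def online_softmax_attention(scores, values, block=2):
    """Reference model: attention output for one query row, streamed in blocks.

    scores: list of floats (one logit per key).
    values: list of row vectors (one per key), all of equal length.
    """
    depth = len(values[0])
    row_max = -math.inf          # running per-row maximum
    row_sum = 0.0                # running softmax denominator
    o_acc = [0.0] * depth        # unnormalized O accumulator

    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]

        new_max = max(row_max, max(s_blk))
        # Correction factor: 1.0 when the maximum is unchanged,
        # 0.0 on the first block because exp(-inf) == 0.0.
        scale = math.exp(row_max - new_max)
        if scale != 1.0:
            # The maximum changed, so everything accumulated so far
            # was computed against a stale maximum and must be rescaled.
            o_acc = [scale * o for o in o_acc]
        row_sum *= scale

        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - new_max)
            row_sum += p
            for j in range(depth):
                o_acc[j] += p * v[j]
        row_max = new_max

    return [o / row_sum for o in o_acc]

def reference_attention(scores, values):
    """Direct (non-streaming) softmax-weighted sum, used for comparison."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / total
            for j in range(len(values[0]))]
```

In this model the rescale of `o_acc` is skipped whenever the running maximum does not change, which is the case where no correction of O is needed.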
Depth-dependent O layout:
split_o=True (depth=512): O is split into O_lo and O_hi (MMA_N=ov_depth/2 each, ov_depth/4 physical cols each). Two-phase rescale: O_lo, then O_hi.
split_o=False (depth=256): a single O (MMA_N=ov_depth, MMA_M*ov_depth/256 physical cols). Single-phase rescale.
All 128 threads participate, processing o_cols physical columns per phase.
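The column arithmetic for the two layouts can be written out as a small Python helper. This is a hypothetical sketch: `rescale_phases` is not part of the module, and the `mma_m=128` value used in the usage lines is an assumption chosen for illustration, not something this page states.

```python
def rescale_phases(ov_depth, split_o, mma_m):
    """Return (accumulator name, physical column count) for each rescale phase.

    Hypothetical helper that mirrors the formulas in the description above.
    """
    if split_o:
        # depth=512 case: O_lo and O_hi, ov_depth/4 physical columns each.
        o_cols = ov_depth // 4
        return [("O_lo", o_cols), ("O_hi", o_cols)]
    # depth=256 case: one accumulator, MMA_M*ov_depth/256 physical columns.
    return [("O", mma_m * ov_depth // 256)]

# mma_m=128 is an assumed value for illustration only.
phases_512 = rescale_phases(512, split_o=True, mma_m=128)
phases_256 = rescale_phases(256, split_o=False, mma_m=128)
```

Under that assumed MMA_M, each phase covers the same number of physical columns in both layouts; the split layout simply runs two such phases.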
Functions