
Mojo module

correction_warp

Correction warp group logic for depth=256/512 pair-CTA SM100 attention.

Rescales the O accumulator in TMEM when the per-row maximum changes during online softmax.
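The rescale step can be illustrated with a minimal scalar sketch of online softmax for one query row. This is not the module's code: the function names `attend_row_online` and `attend_row_reference` are hypothetical, plain Python lists stand in for TMEM, and the scale factor `exp(old_max - new_max)` is the standard online-softmax correction, assumed here to be what the correction warp group applies.

```python
import math


def attend_row_online(score_blocks, value_blocks):
    """Softmax-weighted sum of values for one row, one key block at a time."""
    depth = len(value_blocks[0][0])
    row_max = -math.inf      # running maximum of the scores seen so far
    row_sum = 0.0            # running softmax denominator
    o_acc = [0.0] * depth    # unnormalized output accumulator (stand-in for O)

    for scores, values in zip(score_blocks, value_blocks):
        new_max = max(row_max, max(scores))
        if new_max != row_max:
            # Correction step: earlier terms were computed against the old
            # maximum, so rescale them to the new one. On the first block
            # this is exp(-inf) == 0.0, which is harmless because o_acc is 0.
            scale = math.exp(row_max - new_max)
            o_acc = [o * scale for o in o_acc]
            row_sum *= scale
            row_max = new_max
        for s, v in zip(scores, values):
            p = math.exp(s - row_max)
            row_sum += p
            for d in range(depth):
                o_acc[d] += p * v[d]

    return [o / row_sum for o in o_acc]


def attend_row_reference(scores, values):
    """Same result computed in one pass with the global maximum."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    depth = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(depth)]
```

When a block does not raise the running maximum, the accumulator is left untouched, which is why the rescale only needs to run when the per-row maximum changes.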

Depth-dependent O layout:

split_o=True (depth=512): O is split into O_lo and O_hi (MMA_N=ov_depth/2 each, ov_depth/4 physical cols each). Two-phase rescale: O_lo, then O_hi.

split_o=False (depth=256): Single O (MMA_N=ov_depth, MMA_M*ov_depth/256 physical cols). Single-phase rescale.

All 128 threads participate, processing o_cols physical columns per phase.
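The column arithmetic for the two layouts can be worked through with a small helper. This is a sketch under assumptions: `o_rescale_plan` is a hypothetical name, `o_cols` is taken to equal the per-phase physical column count quoted in the layout description, and `mma_m=128` in the usage lines is an illustrative value rather than one stated on this page.

```python
def o_rescale_plan(ov_depth, mma_m, split_o):
    """Return phase count, MMA_N, and physical columns per rescale phase."""
    if split_o:
        # O_lo and O_hi each cover half the depth and are rescaled in turn.
        return {"phases": 2, "mma_n": ov_depth // 2, "o_cols": ov_depth // 4}
    # Single O accumulator, rescaled in one phase.
    return {"phases": 1, "mma_n": ov_depth, "o_cols": mma_m * ov_depth // 256}


# depth=512, split layout: two phases of 128 physical columns each.
plan_512 = o_rescale_plan(ov_depth=512, mma_m=128, split_o=True)

# depth=256, single layout, assuming mma_m=128: one phase of 128 columns.
plan_256 = o_rescale_plan(ov_depth=256, mma_m=128, split_o=False)
```

Under these assumptions both depths process the same number of physical columns per phase, with depth=512 needing a second pass for O_hi.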

Functions
