Mojo module
allreduce_rmsnorm_fp8
Fused allreduce + RMSNorm + FP8 quantization kernel.
Combines P2P allreduce, RMSNorm normalization, and FP8 dynamic quantization into a single kernel launch. Data stays in registers throughout, with no global memory intermediate between allreduce and RMSNorm.
Design:
- P2P loads from all GPUs (like 1-stage allreduce's load_reduce)
- Accumulation in float32 registers (for bfloat16 inputs)
- RMSNorm computation (warp-tiling: persistent row loop)
- FP8 dynamic per-row quantization
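The four design steps above can be read as a per-row numerical recipe. The following is a minimal host-side sketch of that recipe in plain Python, not the kernel itself. The function name `fused_row_reference`, the `eps` default, the use of 448.0 as the FP8 maximum (the largest finite value of the E4M3 format), and the `row_max / FP8_MAX` scale convention are all assumptions made for illustration; the actual kernel's FP8 type, epsilon, and scale handling may differ. Python floats are 64-bit, so this only mirrors the order of operations, not float32 rounding, and it does not round values onto the FP8 grid.

```python
import math

# Assumption: E4M3-style FP8 with a largest finite value of 448.0.
FP8_MAX = 448.0


def fused_row_reference(rank_rows, gamma, eps=1e-6):
    """Reference semantics for one row (hypothetical helper).

    rank_rows: list of rows, one per GPU, each a list of floats.
    gamma:     RMSNorm weights, same length as a row.
    Returns (scaled_values, scale), where scaled_values lie in
    [-FP8_MAX, FP8_MAX] and value ~= scaled_value * scale.
    """
    # 1. Allreduce: element-wise sum of the same row from every GPU.
    acc = [sum(vals) for vals in zip(*rank_rows)]

    # 2. RMSNorm: divide by the root mean square, then apply gamma.
    mean_sq = sum(x * x for x in acc) / len(acc)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    normed = [x * inv_rms * g for x, g in zip(acc, gamma)]

    # 3. Dynamic per-row quantization: one scale per row, chosen so the
    #    largest magnitude in the row maps to FP8_MAX.
    row_max = max(abs(v) for v in normed)
    scale = row_max / FP8_MAX if row_max > 0.0 else 1.0
    scaled = [max(-FP8_MAX, min(FP8_MAX, v / scale)) for v in normed]
    return scaled, scale


# Two GPUs, one row of two elements.
values, scale = fused_row_reference([[1.0, 2.0], [3.0, -6.0]], [1.0, 0.5])
```

In the fused kernel all three stages operate on register-resident values, so the intermediate lists `acc` and `normed` here correspond to data that is never written to global memory.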
Each block processes multiple rows via a grid-strided loop, allowing row counts beyond MAX_NUM_BLOCKS_UPPER_BOUND (512). Gamma weights are preloaded once and reused across all rows in the loop.
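To illustrate how a grid-strided loop lets a capped number of blocks cover an arbitrary number of rows, here is a small Python sketch of the row assignment. The helper `rows_for_block` and the choice to launch `min(num_rows, MAX_NUM_BLOCKS_UPPER_BOUND)` blocks are assumptions for illustration; the kernel's real launch configuration is not stated on this page.

```python
MAX_NUM_BLOCKS_UPPER_BOUND = 512  # value quoted in the text above


def rows_for_block(block_idx, num_blocks, num_rows):
    """Rows visited by one block in a grid-strided loop (hypothetical helper).

    The block starts at its own index and advances by the grid size,
    so every row is handled by exactly one block.
    """
    return list(range(block_idx, num_rows, num_blocks))


num_rows = 1300
# Assumption: the grid is capped at the upper bound.
num_blocks = min(num_rows, MAX_NUM_BLOCKS_UPPER_BOUND)

first = rows_for_block(0, num_blocks, num_rows)    # [0, 512, 1024]
last = rows_for_block(511, num_blocks, num_rows)   # [511, 1023]
```

Because each block walks several rows, anything that is constant across rows, such as the gamma weights, only needs to be loaded once per block before the loop starts.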
Functions
- allreduce_rmsnorm_fp8: Fused allreduce + RMSNorm + FP8 quantization.