Mojo module

allreduce_rmsnorm_fp8

Fused allreduce + RMSNorm + FP8 quantization kernel.

Combines P2P allreduce, RMSNorm, and FP8 dynamic quantization into a single kernel launch. Data stays in registers throughout: nothing is written to global memory between the allreduce and RMSNorm steps.

Design:

  1. P2P loads from all GPUs (like 1-stage allreduce's load_reduce)
  2. Accumulation in float32 registers (for bfloat16 inputs)
  3. RMSNorm computation (warp-tiling: persistent row loop)
  4. FP8 dynamic per-row quantization
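The four steps above can be sketched as a plain-Python numerical reference. This is not the kernel itself, only an illustration of the arithmetic it fuses. Several details are assumptions not stated on this page: the function name `allreduce_rmsnorm_fp8_reference`, the `eps` value, the use of the E4M3 maximum (448) as the FP8 range, and the scale convention `amax / 448`. Python floats are float64 rather than float32, and rounding to an actual FP8 bit pattern is not emulated.

```python
import math

# Assumption: FP8 E4M3, whose largest finite value is 448. The page does not
# say which FP8 variant the kernel targets.
FP8_E4M3_MAX = 448.0


def allreduce_rmsnorm_fp8_reference(per_gpu_rows, gamma, eps=1e-6):
    """Hypothetical reference for the fused pipeline.

    per_gpu_rows: one entry per GPU, each a list of rows (lists of floats).
    gamma: RMSNorm weights, one per column.
    Returns (scaled_rows, scales), with one scale per row.
    """
    num_rows = len(per_gpu_rows[0])
    scaled_rows, scales = [], []
    for r in range(num_rows):
        # Steps 1-2: read row r from every GPU and sum element-wise.
        acc = [math.fsum(vals) for vals in zip(*(gpu[r] for gpu in per_gpu_rows))]

        # Step 3: RMSNorm, x / sqrt(mean(x^2) + eps) * gamma.
        rms = math.sqrt(sum(x * x for x in acc) / len(acc) + eps)
        normed = [x / rms * g for x, g in zip(acc, gamma)]

        # Step 4: dynamic per-row scale so the largest magnitude maps to the
        # FP8 maximum, then clamp into the representable range.
        amax = max(abs(x) for x in normed)
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
        scaled = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x / scale)) for x in normed]

        scaled_rows.append(scaled)
        scales.append(scale)
    return scaled_rows, scales
```

Multiplying a scaled row by its scale recovers the normalized values (up to the FP8 rounding that this sketch omits), which is how a consumer would dequantize under the assumed convention.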

Each block processes multiple rows via a grid-strided loop, allowing row counts beyond MAX_NUM_BLOCKS_UPPER_BOUND (512). Gamma weights are preloaded once and reused across all rows in the loop.
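To illustrate the grid-strided row loop, here is a small Python simulation. It assumes the common pattern in which block `b` handles rows `b`, `b + num_blocks`, `b + 2 * num_blocks`, and so on, and that the grid size is `min(num_rows, 512)`. Both the mapping and the grid-size choice are assumptions, and the helper names `rows_for_block` and `simulate` are hypothetical; only the 512 bound comes from the text.

```python
MAX_NUM_BLOCKS_UPPER_BOUND = 512  # value quoted in the text above


def rows_for_block(block_id, num_blocks, num_rows):
    """Rows visited by one block under an assumed grid-stride mapping."""
    return list(range(block_id, num_rows, num_blocks))


def simulate(num_rows):
    """Count row visits and gamma loads across a simulated grid."""
    # Assumption: the launch uses at most MAX_NUM_BLOCKS_UPPER_BOUND blocks.
    num_blocks = min(num_rows, MAX_NUM_BLOCKS_UPPER_BOUND)
    gamma_loads = 0
    visits = [0] * num_rows
    for block_id in range(num_blocks):
        gamma_loads += 1  # gamma is loaded once per block, before the row loop
        for row in rows_for_block(block_id, num_blocks, num_rows):
            visits[row] += 1  # every row in the loop reuses the same gamma
    return num_blocks, gamma_loads, visits
```

With 1300 rows, the simulated grid is capped at 512 blocks, block 0 handles rows 0, 512, and 1024, and every row is visited exactly once while gamma is loaded only 512 times rather than 1300.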

Functions