
Mojo module

allreduce_residual_rmsnorm_fp8

Fused allreduce + RMSNorm + FP8 quantization kernel.

Combines P2P allreduce, RMSNorm, and dynamic FP8 quantization into a single kernel launch (the split path described below is the exception and uses two). The reduced values stay in registers, so no global-memory intermediate is written between the allreduce and RMSNorm.

Three dispatch paths:

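1-stage (small payloads): Each thread loads its elements from all GPUs, accumulates them in float32 registers, and then applies RMSNorm and FP8 quantization. This path is simple, but for N elements every GPU reads the full tensor from all GPUs, which gives O(N × ngpus) P2P traffic.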
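A minimal CPU reference in Python can make the 1-stage math concrete. It is a sketch under several assumptions. RMSNorm is the plain form x / sqrt(mean(x²) + eps) * gamma, and eps = 1e-6 is a hypothetical value. The FP8 format is float8_e4m3fn, with a largest finite value of 448. Each row gets a dynamic scale of amax / 448. The names one_stage and round_to_e4m3 are illustrative and not part of this module. Nested Python lists stand in for the per-GPU buffers.

```python
import math

FP8_E4M3_MAX = 448.0  # assumed FP8 format: float8_e4m3fn, largest finite value 448


def round_to_e4m3(v):
    """Round to the nearest float8_e4m3fn value (ties to even), saturating at +/-448."""
    if v == 0.0:
        return 0.0
    a = min(abs(v), FP8_E4M3_MAX)
    _, e = math.frexp(a)        # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)        # unbiased exponent; -6 is e4m3's smallest normal exponent
    step = 2.0 ** (exp - 3)     # 3 mantissa bits -> spacing between neighbouring values
    q = round(a / step) * step  # Python's round() breaks ties to even
    return math.copysign(min(q, FP8_E4M3_MAX), v)


def one_stage(inputs, gamma, eps=1e-6):
    """inputs[g][r][c] is GPU g's copy of element (r, c); returns FP8 rows and per-row scales."""
    ngpus, rows, cols = len(inputs), len(inputs[0]), len(inputs[0][0])
    q_rows, scales = [], []
    for r in range(rows):
        # Load the row from every GPU and accumulate in float (f32 registers in the kernel).
        acc = [sum(inputs[g][r][c] for g in range(ngpus)) for c in range(cols)]
        # RMSNorm: x / sqrt(mean(x^2) + eps) * gamma
        inv_rms = 1.0 / math.sqrt(sum(x * x for x in acc) / cols + eps)
        y = [x * inv_rms * w for x, w in zip(acc, gamma)]
        # Dynamic per-row scale maps the row's largest magnitude onto FP8_E4M3_MAX.
        amax = max(abs(v) for v in y)
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
        q_rows.append([round_to_e4m3(v / scale) for v in y])
        scales.append(scale)
    return q_rows, scales


# Two simulated GPUs, each holding one 4-element row.
q, s = one_stage([[[1.0, 2.0, 3.0, 4.0]], [[1.0, 0.0, -1.0, 0.0]]], gamma=[1.0] * 4)
print(q, s)
```

The reduced row here is [2, 2, 2, 4]. After normalization and scaling, the largest element lands on 448 and the others land on 224. Multiplying by the returned scale recovers the normalized values up to FP8 rounding.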

2-stage (large payloads, non-residual or medium residual): A single kernel with two stages separated by a per-block barrier.

Stage 1 (reduce-scatter + RMSNorm + FP8): Each block reduces the rows in its partition from all GPUs in f32 registers. It then normalizes and quantizes them to FP8 immediately, so no f32 scratch is needed. The block writes compact FP8 data, a per-row scale, and, optionally, the bf16 residual to local scratch.

Stage 2 (all-gather copy): Each block reads the compact FP8 data and scales (plus the bf16 residual, if present) from the owning GPU's scratch over P2P. It then writes them to the local output buffers. This stage does no compute; it only copies data that is 4x smaller than f32.

Both stages use the same row-to-block mapping, so the per-block barrier correctly synchronizes all data dependencies. Total P2P traffic is N * sizeof(in_dtype) for Stage 1, plus N * sizeof(fp8) [+ N * sizeof(in_dtype) if residual] for Stage 2. Because the all-gather copies FP8 (1 byte/elem) instead of f32 (4 bytes/elem), Stage 2 bandwidth drops by about 4x in the non-residual case.
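The following Python sketch simulates the 2-stage data flow and counts the bytes each GPU reads. Its purpose is to check the traffic formula above for the non-residual case. It rests on a few assumptions:

- Row ownership is a contiguous chunk of rows per GPU. The real row-to-block mapping is not specified on this page.
- in_dtype is 2 bytes (bf16).
- Per-row scales are left out of the byte counts.
- The FP8 step applies only the dynamic scale, without rounding to a 3-bit mantissa.

Reads from the local GPU are counted together with remote ones. The function names are hypothetical.

```python
import math

FP8_MAX = 448.0  # assumed FP8 format: float8_e4m3fn, largest finite value 448


def norm_quant_row(row, gamma, eps=1e-6):
    """RMSNorm + per-row dynamic FP8 scaling (3-bit mantissa rounding omitted)."""
    inv_rms = 1.0 / math.sqrt(sum(x * x for x in row) / len(row) + eps)
    y = [x * inv_rms * w for x, w in zip(row, gamma)]
    amax = max(abs(v) for v in y)
    scale = amax / FP8_MAX if amax > 0 else 1.0
    return [v / scale for v in y], scale


def two_stage(inputs, gamma, in_bytes=2):
    """Simulate the 2-stage path; returns per-GPU outputs and per-GPU byte counts."""
    ngpus, rows, cols = len(inputs), len(inputs[0]), len(inputs[0][0])
    chunk = math.ceil(rows / ngpus)
    owner = [r // chunk for r in range(rows)]  # hypothetical: contiguous row chunks per GPU
    scratch = [{} for _ in range(ngpus)]       # per-GPU compact (fp8 row, scale) buffers
    stage1 = [0] * ngpus
    stage2 = [0] * ngpus

    # Stage 1 (reduce-scatter + RMSNorm + FP8): each GPU reduces only the rows it owns,
    # reading them from every GPU and quantizing straight away (no f32 scratch).
    for g in range(ngpus):
        for r in range(rows):
            if owner[r] == g:
                acc = [sum(inputs[src][r][c] for src in range(ngpus)) for c in range(cols)]
                stage1[g] += ngpus * cols * in_bytes
                scratch[g][r] = norm_quant_row(acc, gamma)

    # In the kernel a per-block barrier separates the stages; this sequential
    # simulation simply finishes Stage 1 before starting Stage 2.

    # Stage 2 (all-gather copy): every GPU copies each row's compact result from its owner.
    outputs = []
    for g in range(ngpus):
        local = []
        for r in range(rows):
            q, scale = scratch[owner[r]][r]
            local.append((list(q), scale))
            stage2[g] += cols  # 1 byte per FP8 element
        outputs.append(local)
    return outputs, stage1, stage2


ngpus, rows, cols = 4, 8, 16
inputs = [[[float((g + 1) * (r + c) % 7 - 3) for c in range(cols)] for r in range(rows)]
          for g in range(ngpus)]
outputs, stage1, stage2 = two_stage(inputs, gamma=[1.0] * cols)
N = rows * cols
print(stage1, stage2)  # per-GPU bytes read in Stage 1 and Stage 2
```

The example uses 4 GPUs and 8 rows of 16 elements, so N = 128. Each GPU reads 256 bytes in Stage 1 (N * 2) and 128 bytes in Stage 2 (N * 1), and every GPU ends with an identical output.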

Split (large residual payloads): Two separate kernel launches: an allreduce with an add epilogue, followed by a fused RMSNorm + FP8 kernel. This avoids carrying the bf16 residual through the scratch buffers in both stages, which in the fused path nearly doubles the Stage 2 copy bandwidth. The caller-provided residual_output buffer serves as the intermediate, so no extra allocation is needed.
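This Python sketch shows the split path, with residual_output acting as the hand-off buffer between the two launches. It makes three assumptions:

- The add epilogue computes residual_output = (sum over GPUs) + residual.
- RMSNorm is applied to that sum.
- FP8 conversion is reduced to dynamic scaling.

choose_path is a hypothetical dispatcher. Its payload thresholds are parameters because this page does not state the tuned values.

```python
import math


def allreduce_add(inputs, residual, residual_output):
    """Launch 1 (sketch): sum each element across GPUs, add the residual, store in residual_output."""
    for r, res_row in enumerate(residual):
        for c, res in enumerate(res_row):
            residual_output[r][c] = sum(gpu[r][c] for gpu in inputs) + res


def rmsnorm_scale(residual_output, gamma, eps=1e-6, fp8_max=448.0):
    """Launch 2 (sketch): RMSNorm + per-row dynamic FP8 scaling (mantissa rounding omitted)."""
    out = []
    for row in residual_output:
        inv_rms = 1.0 / math.sqrt(sum(x * x for x in row) / len(row) + eps)
        y = [x * inv_rms * w for x, w in zip(row, gamma)]
        amax = max(abs(v) for v in y)
        scale = amax / fp8_max if amax > 0 else 1.0
        out.append(([v / scale for v in y], scale))
    return out


def choose_path(payload_bytes, has_residual, small_limit, split_limit):
    """Hypothetical dispatch mirroring the three paths; thresholds are caller-supplied."""
    if payload_bytes <= small_limit:
        return "1-stage"
    if has_residual and payload_bytes > split_limit:
        return "split"
    return "2-stage"


# Two GPUs, one row; residual_output is the only buffer shared between the launches.
inputs = [[[1.0, 2.0]], [[3.0, 4.0]]]
residual = [[0.5, -0.5]]
residual_output = [[0.0, 0.0]]
allreduce_add(inputs, residual, residual_output)
quantized = rmsnorm_scale(residual_output, gamma=[1.0, 1.0])
print(residual_output, quantized)
```

Because the second launch reads residual_output directly, the residual never passes through the fused kernel's scratch buffers.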

Each block processes multiple rows via a grid-strided loop, allowing row counts beyond the hardware-tuned block limit. Gamma weights are preloaded once and reused across all rows in the loop.
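The grid-strided row loop can be sketched as follows. It assumes the conventional mapping in which block b handles rows b, b + num_blocks, b + 2*num_blocks, and so on. grid_strided_rows, run_blocks, and the callbacks are hypothetical names. The counter shows gamma being loaded once per block rather than once per row.

```python
def grid_strided_rows(num_rows, num_blocks):
    """Rows visited by each block under a grid-stride loop: b, b + num_blocks, ..."""
    return [list(range(b, num_rows, num_blocks)) for b in range(num_blocks)]


def run_blocks(num_rows, num_blocks, load_gamma, process_row):
    """Sequentially simulate every block; each block loads gamma once, before its row loop."""
    for rows in grid_strided_rows(num_rows, num_blocks):
        gamma = load_gamma()       # hypothetical callback: one gamma load per block
        for r in rows:
            process_row(r, gamma)  # every row in this block reuses the preloaded gamma


loads = [0]    # number of gamma loads performed
visited = []   # rows processed, in visit order


def load_gamma():
    loads[0] += 1
    return [1.0] * 8


# 10 rows but only 4 blocks: rows past the block count are picked up on later iterations.
run_blocks(10, 4, load_gamma, lambda r, gamma: visited.append(r))
print(grid_strided_rows(10, 4), loads[0])
```

Every row is visited exactly once, and gamma is loaded 4 times (once per block) instead of 10 times (once per row).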

Functions
