IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

allreduce_residual_rmsnorm

Fused allreduce + RMSNorm (+ optional FP8 quantization) kernel.

Combines P2P allreduce, RMSNorm normalization, and optional FP8 dynamic quantization into a single kernel launch. Data stays in registers throughout, so there is no global-memory intermediate between the allreduce and RMSNorm.
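To show what the fusion saves, here is a small NumPy sketch. The unfused pipeline writes the full reduced tensor to a buffer before RMSNorm reads it back. The fused loop keeps each reduced row only in a local variable, which stands in for registers. Both should produce the same values. The function names and the `eps` default are assumptions. The residual add and quantization are left out, and NumPy cannot model actual register residency.

```python
import numpy as np

def unfused(inputs, gamma, eps=1e-6):
    # Step 1 (allreduce): the full reduced tensor is materialized in a buffer.
    reduced = np.zeros(inputs[0].shape, dtype=np.float32)
    for x in inputs:
        reduced += x.astype(np.float32)
    # Step 2 (RMSNorm): the buffer is read back and normalized row by row.
    rms = np.sqrt(np.mean(reduced * reduced, axis=-1, keepdims=True) + eps)
    return reduced / rms * gamma

def fused(inputs, gamma, eps=1e-6):
    out = np.empty(inputs[0].shape, dtype=np.float32)
    for r in range(out.shape[0]):
        # The reduced row exists only in a loop-local variable ("registers").
        acc = np.zeros(out.shape[1], dtype=np.float32)
        for x in inputs:
            acc += x[r].astype(np.float32)
        # Normalize immediately; no full-size intermediate is ever written.
        out[r] = acc / np.sqrt(np.mean(acc * acc) + eps) * gamma
    return out
```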

Quantization is fused in only when the output dtype differs from the input dtype (out_dtype != in_dtype). When out_dtype == in_dtype (e.g. a bf16 input with a bf16 output), the FP8 phases (per-row max, dynamic scale, and quantize) are skipped at compile time, the normalized value is written directly in the input dtype, and no per-row scale is produced.
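The following is a minimal per-row sketch of this branch in Python with NumPy. The function name and dtype strings are hypothetical. FP8 is assumed to be e4m3, whose largest finite value is 448. NumPy has no FP8 dtype, so the quantize phase only scales and clamps; it does not round to representable e4m3 values. In the kernel the branch is resolved at compile time, but here it is an ordinary `if`.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite magnitude of float8 e4m3fn

def maybe_quantize_row(normed_row, in_dtype, out_dtype):
    """Return (output_row, scale); scale is None when quantization is skipped."""
    if out_dtype == in_dtype:
        # Same dtype: skip every FP8 phase and write the normalized value as-is.
        return normed_row.astype(in_dtype), None
    # Phase 1: per-row absolute max.
    amax = float(np.max(np.abs(normed_row)))
    # Phase 2: dynamic scale that maps amax onto the FP8 range (guard zero rows).
    scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    # Phase 3: quantize. Scaled and clamped only; e4m3 rounding is not modeled.
    q = np.clip(normed_row / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX).astype(np.float32)
    return q, scale
```

When `out_dtype == in_dtype`, the function returns no scale. This matches the bf16-in/bf16-out case, with float16 standing in for bf16 because NumPy lacks a bf16 type.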

Three dispatch paths:

1-stage (small payloads): Each thread loads from all GPUs, accumulates in float32 registers, then applies RMSNorm (+ optional FP8 quantization). Simple, but it costs O(N × ngpus) P2P traffic.

2-stage (large payloads, non-residual or medium residual): A single kernel with two stages separated by a per-block barrier.

Stage 1 (RS + RMSNorm + FP8), the reduce-scatter stage: each block reduces the rows in its partition from all GPUs in f32 registers, then immediately normalizes and quantizes them to FP8, so no f32 scratch is needed. It writes compact fp8 data, a per-row scale, and (optionally) the bf16 residual to local scratch.

Stage 2 (AG copy), the all-gather stage: each block reads the compact fp8/scale/bf16 data from the owning GPU's scratch over P2P and writes it to the local output buffers. There is no compute, only copies of data that is much smaller than f32 (fp8 is 4x smaller).

Both stages use the same row-to-block mapping, so the per-block barrier correctly synchronizes all data dependencies. Total P2P traffic is N * sizeof(in_dtype) for Stage 1 (same as 1-stage) plus N * sizeof(fp8) [+ N * sizeof(in_dtype) if residual] for Stage 2. Because the AG phase copies fp8 (1 byte/elem) instead of f32 (4 bytes/elem), Stage 2 bandwidth drops by ~4x in the non-residual case.

Split (large residual payloads): Two separate kernel launches: an allreduce with an add epilogue, followed by a fused RMSNorm + FP8 kernel. This avoids carrying the bf16 residual through the scratch buffers in both stages, which in the fused path nearly doubles Stage 2 copy bandwidth. The caller-provided residual_output buffer serves as the intermediate, so no extra allocation is needed.
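The three dispatch paths differ in data movement rather than arithmetic, so they should agree on the result. The NumPy sketch below simulates them sequentially in one process to make that concrete. It is an illustration under stated assumptions, not the module's implementation:

- Every function name is hypothetical.
- Thread blocks are folded into ranks.
- The round-robin row-to-rank partition is invented for the example.
- The residual is assumed to be added after the allreduce.
- FP8 is modeled as a scale-and-clamp onto the e4m3 range, without rounding.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8 e4m3fn magnitude

def reduce_rows(inputs, rows, residual=None):
    # Sum the selected rows from every GPU in float32, plus the residual if given.
    acc = np.zeros((len(rows), inputs[0].shape[1]), dtype=np.float32)
    for x in inputs:
        acc += x[rows].astype(np.float32)
    if residual is not None:
        acc += residual[rows].astype(np.float32)
    return acc

def rmsnorm_fp8(x, gamma, eps=1e-6):
    # Row-wise RMSNorm, then a dynamic per-row scale onto the e4m3 range.
    # NumPy has no FP8 dtype, so values are scaled and clamped but not rounded.
    x = x.astype(np.float32)
    normed = x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma
    amax = np.max(np.abs(normed), axis=-1, keepdims=True)
    scale = np.where(amax > 0, amax / FP8_E4M3_MAX, 1.0).astype(np.float32)
    q = np.clip(normed / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scale[:, 0]

def one_stage(inputs, gamma, residual=None):
    # Every rank reduces every row itself, reading the full buffer from all GPUs.
    rows = list(range(inputs[0].shape[0]))
    return [rmsnorm_fp8(reduce_rows(inputs, rows, residual), gamma) for _ in inputs]

def two_stage(inputs, gamma, residual=None):
    ngpus, nrows = len(inputs), inputs[0].shape[0]
    # Hypothetical partition: rank k owns rows k, k + ngpus, k + 2*ngpus, ...
    owned = [list(range(rank, nrows, ngpus)) for rank in range(ngpus)]
    # Stage 1 (RS + RMSNorm + FP8): each rank reduces only its own rows and
    # keeps the compact result (quantized rows + per-row scales) in its scratch.
    scratch = [rmsnorm_fp8(reduce_rows(inputs, rows, residual), gamma) for rows in owned]
    # The barrier sits here: Stage 2 must not read a row before its owner wrote it.
    # Stage 2 (AG copy): every rank copies each owner's compact rows, no compute.
    outputs = []
    for _ in range(ngpus):
        q = np.empty(inputs[0].shape, dtype=np.float32)
        s = np.empty(nrows, dtype=np.float32)
        for owner_rows, (owner_q, owner_s) in zip(owned, scratch):
            q[owner_rows] = owner_q
            s[owner_rows] = owner_s
        outputs.append((q, s))
    return outputs

def split_path(inputs, gamma, residual, residual_output):
    # Launch 1: allreduce with an add epilogue, written into the caller's buffer.
    rows = list(range(inputs[0].shape[0]))
    residual_output[...] = reduce_rows(inputs, rows, residual)
    # Launch 2: fused RMSNorm + FP8, reading the intermediate back from that buffer.
    # (With a bf16 buffer, this re-read would add a bf16 rounding step.)
    return rmsnorm_fp8(residual_output, gamma)
```

In this sketch, `one_stage` has every rank read every row from every GPU. By contrast, `two_stage` reads each row from all GPUs only once, on its owning rank, and then copies the compact quantized rows. `split_path` matches the other two here only because its intermediate buffer is float32.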

Each block processes multiple rows via a grid-strided loop, allowing row counts beyond the hardware-tuned block limit. Gamma weights are preloaded once and reused across all rows in the loop.
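Below is a sequential NumPy sketch of the grid-strided row loop. It uses a hypothetical `num_blocks` in place of the hardware-tuned block limit and a plain RMSNorm with an assumed `eps`, omitting the allreduce and quantization. Each simulated block loads gamma once and then visits rows `block`, `block + num_blocks`, and so on. As a result, every row is handled exactly once even when there are more rows than blocks.

```python
import numpy as np

def rmsnorm_grid_stride(x, gamma, num_blocks, eps=1e-6):
    """Simulate blocks sweeping rows with a grid-strided loop (blocks run sequentially)."""
    num_rows = x.shape[0]
    out = np.empty(x.shape, dtype=np.float32)
    visits = np.zeros(num_rows, dtype=int)
    for block in range(num_blocks):
        # Preload gamma once per block; it is reused for every row the block visits.
        g = gamma.astype(np.float32)
        # Grid-strided loop: block b handles rows b, b + num_blocks, b + 2*num_blocks, ...
        for row in range(block, num_rows, num_blocks):
            r = x[row].astype(np.float32)
            out[row] = r / np.sqrt(np.mean(r * r) + eps) * g
            visits[row] += 1
    return out, visits
```

For example, with 10 rows and 3 blocks, block 0 handles rows 0, 3, 6, 9; block 1 handles rows 1, 4, 7; and block 2 handles rows 2, 5, 8.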

Functions