Mojo function

allreduce_rmsnorm

def allreduce_rmsnorm[in_dtype: DType, out_dtype: DType, scales_dtype: DType, ngpus: Int, in_layout: TensorLayout, in_origin: Origin[mut=in_origin.mut], //](input_buffers: InlineArray[TileTensor[in_dtype, in_layout, in_origin], ngpus], output: TileTensor[out_dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], gamma: TileTensor[in_dtype, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], scale_ub: Float32, scale_output: TileTensor[scales_dtype, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext)

Fused allreduce + RMSNorm with optional FP8 quantization.

Combines a P2P allreduce across GPUs, RMSNorm, and (when the output dtype differs from the input dtype) FP8 dynamic quantization into a single kernel launch, eliminating the global memory round-trip between allreduce output and RMSNorm input. When out_dtype == in_dtype the quantization is skipped: the normalized value is written directly in the input dtype, and scale_ub and scale_output are ignored.

Note: This kernel does not issue an end barrier. The output and scale buffers are safe to read only on the local GPU (stream ordering guarantees visibility). If a remote GPU needs to read these outputs, the caller must insert an explicit barrier. The start barrier of the NEXT allreduce call protects the input buffers that are read by remote GPUs.

Signal buffer sizing:

  • 1-stage path (payload < threshold): size_of[Signal]() only.
  • 2-stage path (payload > threshold): size_of[Signal]() + ceildiv(rows, ngpus) * cols * sizeof(out_dtype) (output) + align_up(ceildiv(rows, ngpus) * sizeof(scales_dtype), simd_width) (scales plus padding, only when quantizing).
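The sizing rule above can be worked through with a short host-side calculation. The sketch below is illustrative only: `signal_buffer_bytes` is a hypothetical helper, the size of the Signal struct and the choice between the 1-stage and 2-stage paths are passed in as arguments because this page gives neither the struct size nor the threshold, and the sample numbers are made up.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer division rounded up."""
    return -(-a // b)


def align_up(value: int, alignment: int) -> int:
    """Round value up to the next multiple of alignment."""
    return ceildiv(value, alignment) * alignment


def signal_buffer_bytes(signal_size: int, rows: int, cols: int, ngpus: int,
                        out_dtype_size: int, scales_dtype_size: int,
                        simd_width: int, two_stage: bool, quantizing: bool) -> int:
    """Bytes needed per signal buffer, following the formula in the text."""
    if not two_stage:
        # 1-stage path: only the Signal struct itself.
        return signal_size
    rows_per_gpu = ceildiv(rows, ngpus)
    total = signal_size + rows_per_gpu * cols * out_dtype_size  # output rows
    if quantizing:
        # One scale per row, padded up to simd_width.
        total += align_up(rows_per_gpu * scales_dtype_size, simd_width)
    return total


# Illustrative numbers: 10 rows x 4096 cols over 4 GPUs, a 1024-byte Signal (assumed).
fp8_case = signal_buffer_bytes(1024, 10, 4096, 4, out_dtype_size=1,
                               scales_dtype_size=4, simd_width=16,
                               two_stage=True, quantizing=True)
bf16_case = signal_buffer_bytes(1024, 10, 4096, 4, out_dtype_size=2,
                                scales_dtype_size=4, simd_width=16,
                                two_stage=True, quantizing=False)
small_case = signal_buffer_bytes(1024, 10, 4096, 4, out_dtype_size=1,
                                 scales_dtype_size=4, simd_width=16,
                                 two_stage=False, quantizing=True)
```

In the FP8 case each GPU owns ceildiv(10, 4) = 3 rows, so the output region is 3 * 4096 * 1 = 12288 bytes and the 12 bytes of scales are padded to 16, giving 1024 + 12288 + 16 = 13328 bytes.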

Parameters:

  • ​in_dtype (DType): Input data type (e.g. bfloat16).
  • ​out_dtype (DType): Output data type. Either a float8 type (fuses quantization) or equal to in_dtype (no quantization).
  • ​scales_dtype (DType): Scale factor data type (e.g. float32). Ignored when out_dtype == in_dtype.
  • ​ngpus (Int): Number of GPUs participating.
  • ​in_layout (TensorLayout): Layout of the input TileTensors.
  • ​in_origin (Origin[mut=in_origin.mut]): Origin of the input TileTensors.

Args: