Mojo function
allreduce_rmsnorm
def allreduce_rmsnorm[in_dtype: DType, out_dtype: DType, scales_dtype: DType, ngpus: Int, in_layout: TensorLayout, in_origin: Origin[mut=in_origin.mut], //](input_buffers: InlineArray[TileTensor[in_dtype, in_layout, in_origin], ngpus], output: TileTensor[out_dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], gamma: TileTensor[in_dtype, Storage=gamma.Storage, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], scale_ub: Float32, scale_output: TileTensor[scales_dtype, Storage=scale_output.Storage, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], Int(8)], ctx: DeviceContext)
Fused allreduce + RMSNorm with optional FP8 quantization.
Combines a P2P allreduce across GPUs, RMSNorm, and (when the output dtype
differs from the input dtype) FP8 dynamic quantization into a single
kernel launch, eliminating the global memory round-trip between allreduce
output and RMSNorm input. When out_dtype == in_dtype the quantization is
skipped: the normalized value is written directly in the input dtype and
scale_ub and scale_output are ignored.
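To make the data flow concrete, here is a minimal host-side reference sketch in plain Python. It is not the kernel's implementation, and several details are assumptions that this page does not confirm: the normalization formula (`x / sqrt(mean(x^2) + epsilon) * (gamma + weight_offset)`), the per-row scale formula (`min(max|y|, scale_ub) / FP8_MAX`), and the FP8 maximum of 448 (the largest finite float8_e4m3fn value). The function name `allreduce_rmsnorm_reference` is hypothetical, and passing `scale_ub=None` stands in for the `out_dtype == in_dtype` case.

```python
import math

# Assumption: the FP8 output type is float8_e4m3fn, whose largest finite value is 448.
FP8_MAX = 448.0


def allreduce_rmsnorm_reference(inputs, gamma, epsilon, weight_offset, scale_ub=None):
    """Hypothetical reference model of the fused operation.

    inputs: one [rows][cols] matrix (list of lists) per GPU.
    gamma: list of length cols.
    scale_ub: None models out_dtype == in_dtype (no quantization).
    Returns (output_rows, per_row_scales or None).
    """
    rows, cols = len(inputs[0]), len(inputs[0][0])
    output = []
    scales = None if scale_ub is None else []
    for r in range(rows):
        # Step 1: allreduce, modeled as an elementwise sum over all ranks.
        summed = [sum(buf[r][c] for buf in inputs) for c in range(cols)]
        # Step 2: RMSNorm over the row (assumed formula, see lead-in).
        mean_sq = sum(v * v for v in summed) / cols
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        normed = [v * inv_rms * (gamma[c] + weight_offset)
                  for c, v in enumerate(summed)]
        if scale_ub is None:
            output.append(normed)
            continue
        # Step 3: dynamic per-row scale, clamped by scale_ub (assumed formula).
        row_max = min(max(abs(v) for v in normed), scale_ub)
        scale = max(row_max, 1e-12) / FP8_MAX  # floor avoids division by zero
        # Values are clamped to the FP8 range; rounding to the FP8 grid is not modeled.
        output.append([max(-FP8_MAX, min(FP8_MAX, v / scale)) for v in normed])
        scales.append(scale)
    return output, scales
```

With two ranks holding `[[3.0, 0.0]]` and `[[0.0, 4.0]]`, the reduced row is `[3.0, 4.0]`, and the quantized row multiplied by its scale recovers the normalized row whenever `scale_ub` does not clamp.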
Note: This kernel does not issue an end barrier. The output and scale buffers are safe to read only on the local GPU (stream ordering guarantees visibility). If a remote GPU needs to read these outputs, the caller must insert an explicit barrier. The start barrier of the NEXT allreduce call protects the input buffers that are read by remote GPUs.
Signal buffer sizing:
- 1-stage path (payload < threshold): size_of[Signal]() only.
- 2-stage path (payload > threshold): size_of[Signal]() + ceildiv(rows, ngpus) * cols * sizeof(out_dtype) (output) + align_up(ceildiv(rows, ngpus) * sizeof(scales_dtype), simd_width) (scales + pad, if quantizing).
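The sizing rule above can be checked with a small calculator. This is a sketch in plain Python, not part of the Mojo API: the helper name `signal_buffer_bytes` is hypothetical, and the size of `Signal`, the 1-stage/2-stage choice, and `simd_width` are passed in by the caller because this page does not give their values.

```python
def ceildiv(a, b):
    """Integer division rounding up."""
    return -(-a // b)


def align_up(value, alignment):
    """Round value up to the next multiple of alignment."""
    return ceildiv(value, alignment) * alignment


def signal_buffer_bytes(signal_bytes, rows, cols, ngpus, out_dtype_bytes,
                        scales_dtype_bytes, simd_width, two_stage, quantizing):
    """Hypothetical helper: bytes needed per signal buffer.

    signal_bytes: size of the Signal struct (caller-supplied; not stated here).
    two_stage: True when the payload is above the (unspecified) threshold.
    quantizing: True when out_dtype differs from in_dtype.
    """
    if not two_stage:
        return signal_bytes
    rows_per_gpu = ceildiv(rows, ngpus)
    total = signal_bytes + rows_per_gpu * cols * out_dtype_bytes  # output
    if quantizing:
        # Per-row scales, padded as in the formula above.
        total += align_up(rows_per_gpu * scales_dtype_bytes, simd_width)
    return total
```

For example, with an assumed 1024-byte `Signal`, 10 rows of 4096 columns across 4 GPUs, a 1-byte FP8 output, 4-byte scales, and `simd_width` of 16, each GPU handles `ceildiv(10, 4) = 3` rows, giving `1024 + 3 * 4096 + align_up(12, 16) = 13328` bytes.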
Parameters:
- in_dtype (DType): Input data type (e.g. bfloat16).
- out_dtype (DType): Output data type. Either a float8 type (fuses quantization) or equal to in_dtype (no quantization).
- scales_dtype (DType): Scale factor data type (e.g. float32). Ignored when out_dtype == in_dtype.
- ngpus (Int): Number of GPUs participating.
- in_layout (TensorLayout): Layout of the input TileTensors.
- in_origin (Origin[mut=in_origin.mut]): Origin of the input TileTensors.
Args:
- input_buffers (InlineArray[TileTensor[in_dtype, in_layout, in_origin], ngpus]): Per-GPU input buffers as TileTensors.
- output (TileTensor[out_dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output buffer (FP8 values when quantizing, else in_dtype).
- gamma (TileTensor[in_dtype, Storage=gamma.Storage, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size]): RMSNorm gamma weights (1D TileTensor of length cols).
- epsilon (Scalar[in_dtype]): RMSNorm epsilon for numerical stability.
- weight_offset (Scalar[in_dtype]): Additive offset for gamma weights.
- scale_ub (Float32): Upper bound for FP8 scale clamping (ignored when not quantizing).
- scale_output (TileTensor[scales_dtype, Storage=scale_output.Storage, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size]): Output buffer for per-row FP8 scales (ignored, and not written, when not quantizing).
- rank_sigs (InlineArray[UnsafePointer[Signal, MutAnyOrigin], Int(8)]): Per-GPU signal pointers for synchronization.
- ctx (DeviceContext): Device context for this GPU.