Mojo function

allreduce_rmsnorm_fp8

allreduce_rmsnorm_fp8[in_dtype: DType, out_dtype: DType, scales_dtype: DType, ngpus: Int, in_layout: TensorLayout, in_origin: Origin[mut=in_origin.mut], //](input_buffers: InlineArray[TileTensor[in_dtype, in_layout, in_origin], ngpus], output: TileTensor[out_dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], gamma: TileTensor[in_dtype, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], scale_ub: Float32, scale_output: TileTensor[scales_dtype, scale_output.LayoutType, scale_output.origin, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext)

Fused allreduce + RMSNorm + FP8 quantization.

Combines P2P allreduce across GPUs, RMSNorm normalization, and FP8 dynamic quantization into a single kernel launch. Eliminates the global memory round-trip between allreduce output and RMSNorm input.
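The per-row math that this kernel fuses (after the allreduce sum) can be sketched in plain Python. This is an illustrative sketch only: the exact scale formula used by the kernel may differ, and `FP8_E4M3_MAX` and the clamping order are assumptions here, not taken from the kernel source.

```python
import math

FP8_E4M3_MAX = 448.0  # max finite value of float8_e4m3fn (assumed target format)

def rmsnorm_fp8_row(x, gamma, epsilon, weight_offset, scale_ub):
    """Sketch of the fused per-row computation: RMSNorm, then
    dynamic FP8 quantization with the row maximum clamped to scale_ub."""
    # RMSNorm: divide by the root-mean-square, apply (gamma + weight_offset).
    rms = math.sqrt(sum(v * v for v in x) / len(x) + epsilon)
    y = [v / rms * (g + weight_offset) for v, g in zip(x, gamma)]
    # Dynamic quantization: one scale per row, derived from the row's
    # absolute maximum, clamped from above by scale_ub.
    amax = min(max(abs(v) for v in y), scale_ub)
    scale = amax / FP8_E4M3_MAX
    q = [v / scale for v in y]  # values now fit within the FP8 range
    return q, scale
```

The kernel writes `q` (cast to `out_dtype`) to `output` and `scale` to `scale_output`, one scale per row.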

Note: This kernel does not issue an end barrier. The FP8 output and scale buffers are safe to read only on the local GPU (stream ordering guarantees visibility). If a remote GPU needs to read these outputs, the caller must insert an explicit barrier. The start barrier of the NEXT allreduce call protects the input buffers that are read by remote GPUs.

Signal buffer sizing:

  • 1-stage path (payload < threshold): size_of[Signal]() only.
  • 2-stage path (payload > threshold): size_of[Signal]() + ceildiv(rows, ngpus) * cols (FP8 data) + align_up(ceildiv(rows, ngpus) * sizeof(scales_dtype), simd_width) (scales + padding).
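The 2-stage sizing formula above can be checked with a few lines of arithmetic. In this sketch, `sizeof_signal` and `simd_width` stand in for the actual `size_of[Signal]()` and hardware SIMD width, which depend on the build and target GPU.

```python
def ceildiv(a, b):
    # Ceiling division for non-negative integers.
    return -(-a // b)

def align_up(x, alignment):
    # Round x up to the next multiple of alignment.
    return ceildiv(x, alignment) * alignment

def signal_buffer_bytes(rows, cols, ngpus, sizeof_signal,
                        sizeof_scales_dtype, simd_width):
    """Bytes required per signal buffer on the 2-stage path,
    following the sizing formula in the docstring."""
    fp8_bytes = ceildiv(rows, ngpus) * cols  # 1 byte per FP8 element
    scale_bytes = align_up(ceildiv(rows, ngpus) * sizeof_scales_dtype,
                           simd_width)
    return sizeof_signal + fp8_bytes + scale_bytes
```

For example, with 1024 rows, 4096 columns, 8 GPUs, and float32 scales, each GPU stages 128 rows of FP8 data plus a SIMD-aligned scales region after the Signal header.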

Parameters:

  • in_dtype (DType): Input data type (e.g. bfloat16).
  • out_dtype (DType): FP8 output data type (e.g. float8_e4m3fn).
  • scales_dtype (DType): Scale factor data type (e.g. float32).
  • ngpus (Int): Number of GPUs participating.
  • in_layout (TensorLayout): Layout of the input TileTensors.
  • in_origin (Origin): Origin of the input TileTensors.

Args:

  • input_buffers (InlineArray): Per-GPU input buffers as TileTensors.
  • output (TileTensor): Output buffer for FP8 values as a TileTensor.
  • gamma (TileTensor): RMSNorm gamma weights (1D TileTensor of length cols).
  • epsilon (Scalar): RMSNorm epsilon for numerical stability.
  • weight_offset (Scalar): Additive offset for gamma weights.
  • scale_ub (Float32): Upper bound for FP8 scale clamping.
  • scale_output (TileTensor): Output buffer for per-row FP8 scales as a TileTensor.
  • rank_sigs (InlineArray): Per-GPU signal pointers for synchronization.
  • ctx (DeviceContext): Device context for this GPU.
