Mojo function

allreduce_residual_rmsnorm_fp8

allreduce_residual_rmsnorm_fp8[in_dtype: DType, out_dtype: DType, scales_dtype: DType, ngpus: Int, in_layout: TensorLayout, in_origin: Origin[mut=in_origin.mut], //](input_buffers: InlineArray[TileTensor[in_dtype, in_layout, in_origin], ngpus], residual: TileTensor[in_dtype, residual.LayoutType, residual.origin, address_space=residual.address_space, linear_idx_type=residual.linear_idx_type, element_size=residual.element_size], output: TileTensor[out_dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], residual_output: TileTensor[in_dtype, residual_output.LayoutType, residual_output.origin, address_space=residual_output.address_space, linear_idx_type=residual_output.linear_idx_type, element_size=residual_output.element_size], gamma: TileTensor[in_dtype, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], scale_ub: Float32, scale_output: TileTensor[scales_dtype, scale_output.LayoutType, scale_output.origin, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext)

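Primary TileTensor-based implementation of allreduce_residual_rmsnorm_fp8. It sums the per-GPU input buffers, adds the residual, writes the pre-norm sum to residual_output, applies RMSNorm with gamma, and writes the FP8 result to output with per-row scales in scale_output.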

Parameters:

  • in_dtype (DType): Input data type (e.g. bfloat16).
  • out_dtype (DType): FP8 output data type (e.g. float8_e4m3fn).
  • scales_dtype (DType): Scale factor data type (e.g. float32).
  • ngpus (Int): Number of GPUs participating.
  • in_layout (TensorLayout): Layout of the input TileTensors.
  • in_origin (Origin): Origin of the input TileTensors.

Args:

  • input_buffers (InlineArray): Per-GPU input buffers as TileTensors.
  • residual (TileTensor): Residual buffer as a TileTensor.
  • output (TileTensor): Output buffer for FP8 values as a TileTensor.
  • residual_output (TileTensor): Output buffer for the pre-norm sum (allreduced input plus residual), stored in in_dtype, as a TileTensor.
  • gamma (TileTensor): RMSNorm gamma weights (1D TileTensor).
  • epsilon (Scalar): RMSNorm epsilon for numerical stability.
  • weight_offset (Scalar): Additive offset for gamma weights.
  • scale_ub (Float32): Upper bound for FP8 scale clamping.
  • scale_output (TileTensor): Output buffer for per-row FP8 scales as a TileTensor.
  • rank_sigs (InlineArray): Per-GPU signal pointers for synchronization. The array type has a fixed length of 8, regardless of ngpus.
  • ctx (DeviceContext): Device context for this GPU.
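
The arguments above suggest the per-row computation sketched below. This is a plain-Python reference of the arithmetic only, not the Mojo kernel. It makes several assumptions that this page does not confirm: tensors are treated as 2D (rows by columns) nested lists, the RMSNorm weight is gamma + weight_offset, the per-row scale is the row's absolute maximum clamped to scale_ub and divided by the float8_e4m3fn maximum (448), and rounding to FP8 is left out. The function name allreduce_residual_rmsnorm_fp8_ref is hypothetical.

```python
import math

FP8_E4M3FN_MAX = 448.0  # largest finite float8_e4m3fn value


def allreduce_residual_rmsnorm_fp8_ref(input_buffers, residual, gamma,
                                       epsilon, weight_offset, scale_ub):
    """Reference sketch; returns (output, residual_output, scale_output)."""
    rows, cols = len(residual), len(residual[0])
    output, residual_output, scale_output = [], [], []
    for r in range(rows):
        # Allreduce (sum across GPUs) plus residual gives the pre-norm sum.
        pre = [sum(buf[r][c] for buf in input_buffers) + residual[r][c]
               for c in range(cols)]
        # RMSNorm over the row, with offset gamma weights (assumed form).
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in pre) / cols + epsilon)
        normed = [v * inv_rms * (gamma[c] + weight_offset)
                  for c, v in enumerate(pre)]
        # Assumed dynamic per-row scale: row max clamped by scale_ub.
        row_max = min(max(abs(v) for v in normed), scale_ub)
        scale = row_max / FP8_E4M3FN_MAX if row_max > 0.0 else 1.0
        # Scaled values clamped to the FP8 range; FP8 rounding is omitted.
        quant = [max(-FP8_E4M3FN_MAX, min(FP8_E4M3FN_MAX, v / scale))
                 for v in normed]
        output.append(quant)
        residual_output.append(pre)
        scale_output.append(scale)
    return output, residual_output, scale_output


# Two simulated GPUs, one row of two columns.
bufs = [[[1.0, 2.0]], [[3.0, 4.0]]]
out, res_out, scales = allreduce_residual_rmsnorm_fp8_ref(
    bufs, residual=[[0.0, 0.0]], gamma=[1.0, 1.0],
    epsilon=0.0, weight_offset=0.0, scale_ub=1e6)
```

In this example res_out holds the summed inputs, and the largest element of each output row lands at the edge of the FP8 range because the scale is derived from that row's maximum. A sketch like this may be useful for checking kernel results on small inputs, provided the assumed formulas match the actual implementation.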