Mojo function

rms_norm_fused_fp8

rms_norm_fused_fp8[in_dtype: DType, out_dtype: DType, scales_dtype: DType, rank: Int, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[in_dtype, width], /, target: StringSlice[StaticConstantOrigin] = "gpu"](shape: IndexList[rank], output: NDBuffer[out_dtype, rank, output.origin, output.shape, output.strides, alignment2=output.alignment2, address_space=output.address_space, exclusive=output.exclusive], gamma: TileTensor[in_dtype, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_shape_types=gamma.element_shape_types], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], ctx: DeviceContextPtr, scale_ub: Float32, scale_output: NDBuffer[scales_dtype, rank, scale_output.origin, scale_output.shape, scale_output.strides, alignment2=scale_output.alignment2, address_space=scale_output.address_space, exclusive=scale_output.exclusive])

Fused RMSNorm + FP8 quantization kernel.

Computes RMSNorm normalization and quantizes the output to FP8 format in a single pass. This fusion eliminates intermediate memory writes and improves performance.

Note: This kernel always multiplies by gamma before quantizing, as required for correct FP8 quantization.

Parameters:

  • in_dtype (DType): Input data type (float32, float16, or bfloat16).
  • out_dtype (DType): Output FP8 data type (float8_e4m3fn or float8_e4m3fnuz).
  • scales_dtype (DType): Data type for scale factors (bfloat16, float16, or float32).
  • rank (Int): Tensor rank.
  • input_fn (fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[in_dtype, width]): Function to load input values.
  • target (StringSlice): Target device ("gpu" or "cpu").

Args:

  • shape (IndexList): Input tensor shape.
  • output (NDBuffer): Output buffer to write FP8 quantized values.
  • gamma (TileTensor): RMSNorm scale parameter (rank 1).
  • epsilon (Scalar): Small constant for numerical stability.
  • weight_offset (Scalar): Offset to add after normalization.
  • ctx (DeviceContextPtr): Device context.
  • scale_ub (Float32): Upper bound clamping the dynamically computed scale factor.
  • scale_output (NDBuffer): Buffer to write per-row dynamic scales (rank-N, last dim = 1).
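To illustrate the computation this kernel fuses, here is a NumPy reference sketch, with FP8 emulated in float32. The per-row scaling convention, the use of scale_ub as a clamp, and the e4m3fn maximum of 448 are assumptions for illustration, not the kernel's actual implementation:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # assumed max finite value of float8_e4m3fn

def rms_norm_fused_fp8_ref(x, gamma, epsilon, weight_offset, scale_ub):
    """Reference sketch of fused RMSNorm + per-row dynamic FP8
    quantization; FP8 is emulated in float32."""
    # RMSNorm over the last axis, applying gamma (plus weight_offset)
    # before quantization, as the kernel note describes.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + epsilon)
    normed = x / rms * (gamma + weight_offset)
    # Per-row dynamic scale so values fit the FP8 range,
    # clamped from above by scale_ub (an assumed convention).
    amax = np.max(np.abs(normed), axis=-1, keepdims=True)
    scale = np.minimum(amax / FP8_E4M3_MAX, scale_ub)
    scale = np.maximum(scale, 1e-12)  # guard against division by zero
    quantized = np.clip(normed / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    # scale has shape (..., 1), matching the rank-N scale_output buffer
    return quantized, scale

x = np.random.randn(4, 8).astype(np.float32)
gamma = np.ones(8, dtype=np.float32)
q, s = rms_norm_fused_fp8_ref(x, gamma, epsilon=1e-6,
                              weight_offset=0.0, scale_ub=1.0)
```

The fused kernel performs both stages in one pass, so `normed` never needs to be written to memory; this sketch materializes it only for clarity.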