Mojo function

rms_norm_fused_fp8

rms_norm_fused_fp8[in_dtype: DType, out_dtype: DType, scales_dtype: DType, rank: Int, input_fn: def[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[in_dtype, width], /, target: StringSlice[StaticConstantOrigin] = StringSlice("gpu"), compile_only: Bool = False](shape: IndexList[rank], output: TileTensor[out_dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], gamma: TileTensor[in_dtype, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], ctx: DeviceContextPtr, scale_ub: Float32, scale_output: TileTensor[scales_dtype, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size])

Fused RMSNorm + FP8 quantization kernel (TileTensor overload).

Applies RMSNorm and quantizes the result to FP8 in a single pass, avoiding a separate quantization kernel and the extra round trip through memory. This is the primary implementation and operates on TileTensor inputs.
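To make the fused computation concrete, here is a rough numerical sketch in NumPy of what the kernel computes per row. This is an illustrative reference, not the Mojo implementation: the per-row dynamic scaling, the interpretation of `scale_ub` as a clamp on the per-row absolute maximum, and the `FP8_E4M3_MAX` constant (448, the largest finite `float8_e4m3fn` value) are assumptions about the kernel's scaling scheme.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def rms_norm_fused_fp8_ref(x, gamma, epsilon, weight_offset, scale_ub):
    """Reference sketch: RMSNorm over the last axis, then FP8 quantization.

    Assumed semantics (not confirmed by the kernel source):
      - weight_offset is added to gamma before scaling (Gemma-style "gamma + 1").
      - scale_ub clamps the per-row absolute maximum before deriving the scale.
    Returns the quantized values (still stored as float32 here) and the
    per-row scale factors needed to dequantize.
    """
    x32 = x.astype(np.float32)
    # RMSNorm: x / sqrt(mean(x^2) + epsilon), scaled by (gamma + weight_offset).
    rms = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + epsilon)
    normed = x32 / rms * (gamma.astype(np.float32) + weight_offset)
    # Dynamic per-row scale: amax (clamped by scale_ub) mapped onto the FP8 range.
    amax = np.minimum(np.abs(normed).max(axis=-1, keepdims=True), scale_ub)
    scale = np.maximum(amax / FP8_E4M3_MAX, 1e-12)  # guard against divide-by-zero
    quantized = np.clip(normed / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return quantized, scale
```

Fusing the two steps matters because the normalized intermediate never has to be written out at full precision; the scale factors (written to `scale_output` in the real kernel) are all a consumer needs to dequantize.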

Parameters:

  • in_dtype (DType): Input data type (float32, float16, or bfloat16).
  • out_dtype (DType): Output FP8 data type (float8_e4m3fn or float8_e4m3fnuz).
  • scales_dtype (DType): Data type for scale factors (bfloat16, float16, or float32).
  • rank (Int): Tensor rank.
  • input_fn (def[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[in_dtype, width]): Function to load input values.
  • target (StringSlice[StaticConstantOrigin]): Target device ("gpu" or "cpu").
  • compile_only (Bool): If True, only compiles the kernel without executing it. Used to pre-compile kernels and avoid JIT compilation deadlocks in multi-GPU contexts.

Args: