IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

rms_norm_fused_fp8

def rms_norm_fused_fp8[in_dtype: DType, out_dtype: DType, scales_dtype: DType, rank: Int, input_fn: def[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[in_dtype, width], /, target: StringSlice[StaticConstantOrigin] = StringSlice("gpu"), compile_only: Bool = False](shape: IndexList[rank], output: TileTensor[out_dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], gamma: TileTensor[in_dtype, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], ctx: DeviceContext, scale_ub: Float32, scale_output: TileTensor[scales_dtype, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size])

Fused RMSNorm + FP8 quantization kernel (TileTensor overload).

Computes RMS normalization and quantizes the result to FP8 in a single pass. This is the primary implementation and operates on TileTensor inputs.
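The following pure-Python reference sketch shows roughly what "RMSNorm + FP8 quantization" computes for each row. It is not the Mojo kernel and not part of the MAX API. The helper names are hypothetical, and the sketch makes several assumptions:

- normalization runs over the last dimension;
- the weight is applied as `gamma + weight_offset`;
- `out_dtype` is `float8_e4m3fn`, whose largest finite value is 448;
- each row gets one dynamic scale, `min(absmax, scale_ub) / 448`, so `scale_ub` caps the row's absolute maximum.

```python
import math

# Assumption: out_dtype is float8_e4m3fn, whose largest finite value is 448.
FP8_MAX = 448.0


def round_to_e4m3fn(v: float) -> float:
    """Round to the nearest float8_e4m3fn value (ties to even), saturating at +/-448."""
    if v == 0.0 or math.isnan(v):
        return v
    sign = -1.0 if v < 0 else 1.0
    a = min(abs(v), FP8_MAX)
    # frexp: a == m * 2**e with 0.5 <= m < 1, so floor(log2(a)) == e - 1.
    _, e = math.frexp(a)
    # Below the smallest normal exponent (-6), subnormals keep the same spacing.
    exp = max(e - 1, -6)
    step = 2.0 ** (exp - 3)  # 3 explicit mantissa bits
    return sign * min(round(a / step) * step, FP8_MAX)  # round() ties to even


def rms_norm_fp8_reference(rows, gamma, epsilon, weight_offset, scale_ub):
    """Hypothetical reference: returns (quantized rows, one scale per row)."""
    quantized, scales = [], []
    for x in rows:
        # RMSNorm over the last dimension, weighted by (gamma + weight_offset).
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + epsilon)
        y = [v * inv_rms * (g + weight_offset) for v, g in zip(x, gamma)]
        # Assumption: scale_ub caps the row's absolute maximum before scaling.
        absmax = min(max(abs(v) for v in y), scale_ub)
        # Assumption: an all-zero row falls back to scale 1.0 to avoid dividing by 0.
        scale = absmax / FP8_MAX if absmax > 0 else 1.0
        quantized.append([round_to_e4m3fn(v / scale) for v in y])
        scales.append(scale)
    return quantized, scales
```

Multiplying a quantized value by its row's scale approximately recovers the normalized value. For `float8_e4m3fnuz`, the constant would be 240 instead of 448.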

Parameters:

  • in_dtype (DType): Input data type (float32, float16, or bfloat16).
  • out_dtype (DType): Output FP8 data type (float8_e4m3fn or float8_e4m3fnuz).
  • scales_dtype (DType): Data type of the scale factors (bfloat16, float16, or float32).
  • rank (Int): Tensor rank.
  • input_fn (def[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[in_dtype, width]): Capturing function that loads input values, returning a SIMD[in_dtype, width] vector for a given index.
  • target (StringSlice[StaticConstantOrigin]): Target device ("gpu" or "cpu").
  • compile_only (Bool): If True, compiles the kernel without executing it. Used to pre-compile kernels and avoid JIT compilation deadlocks in multi-GPU contexts.
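The two accepted `out_dtype` values share the E4M3 bit layout (1 sign bit, 4 exponent bits, 3 mantissa bits). They differ in exponent bias and NaN encoding, so their largest finite values differ: 448 for `float8_e4m3fn` and 240 for `float8_e4m3fnuz`. This matters when choosing a quantization scale. The pure-Python decoder below enumerates all 256 encodings of each format. It is an illustrative sketch with hypothetical helper names, not part of the Mojo or MAX API.

```python
import math


def decode_e4m3(byte: int, fnuz: bool) -> float:
    """Decode one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits).

    fnuz=False -> float8_e4m3fn (bias 7), fnuz=True -> float8_e4m3fnuz (bias 8).
    """
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if fnuz:
        bias = 8
        if byte == 0x80:  # the negative-zero pattern is the only NaN; there is no -0
            return float("nan")
    else:
        bias = 7
        if exp == 0xF and mant == 0x7:  # S.1111.111 is NaN; there are no infinities
            return float("nan")
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (mant / 8.0) * 2.0 ** (1 - bias)
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - bias)


def max_finite(fnuz: bool) -> float:
    """Largest finite value over all 256 encodings."""
    values = (decode_e4m3(b, fnuz) for b in range(256))
    return max(v for v in values if not math.isnan(v))


for name, fnuz in (("float8_e4m3fn", False), ("float8_e4m3fnuz", True)):
    print(name, max_finite(fnuz))
# float8_e4m3fn 448.0
# float8_e4m3fnuz 240.0
```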

Args: