Mojo function
rms_norm_fused_fp8
rms_norm_fused_fp8[in_dtype: DType, out_dtype: DType, scales_dtype: DType, rank: Int, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[in_dtype, width], /, target: StringSlice[StaticConstantOrigin] = "gpu", use_dynamic_scaling: Bool = True](shape: IndexList[rank], output: NDBuffer[out_dtype, rank, MutAnyOrigin], gamma: TileTensor[in_dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], ctx: DeviceContextPtr, scale_ub: Float32, static_scale: OptionalReg[Float32] = None, scale_output: OptionalReg[NDBuffer[scales_dtype, 1, MutAnyOrigin]] = None)
Fused RMSNorm + FP8 quantization kernel.
Computes RMSNorm and quantizes the result to FP8 in a single pass. Fusing the two steps avoids writing the unquantized normalized tensor to memory and reading it back, which reduces memory traffic.
Note: This kernel always multiplies by gamma before quantizing to FP8, which is the correct behavior for FP8 quantization.
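The order of operations for one row can be illustrated with a plain-Python reference sketch of the math. This is not the Mojo kernel. It normalizes by the row's RMS, applies gamma, and only then divides by a scale and saturates to the FP8 range. It assumes `weight_offset` is added to gamma (the "1 + gamma" convention some models use). The page only says the offset is applied after normalization, so check the kernel source before relying on that detail. The sketch also omits rounding onto the FP8 value grid, and the helper names `rms_norm_gamma_row` and `quantize_row` are hypothetical.

```python
import math

def rms_norm_gamma_row(x, gamma, epsilon, weight_offset=0.0):
    """Hypothetical helper: RMS-normalize one row, then apply gamma."""
    mean_sq = sum(v * v for v in x) / len(x)          # mean of squares over the row
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)      # epsilon keeps the denominator nonzero
    # Assumption: weight_offset is added to gamma (the "1 + gamma" convention).
    return [v * inv_rms * (g + weight_offset) for v, g in zip(x, gamma)]

def quantize_row(y, scale, fp8_max):
    """Hypothetical helper: divide by the scale and saturate to the FP8 range.
    Rounding onto the FP8 value grid is omitted."""
    return [max(-fp8_max, min(fp8_max, v / scale)) for v in y]

row = [1.0, -2.0, 3.0, -4.0]
gamma = [1.0, 1.0, 0.5, 0.5]
y = rms_norm_gamma_row(row, gamma, epsilon=1e-6)   # gamma is applied first...
scale = max(abs(v) for v in y) / 448.0             # ...so the scale sees gamma-scaled values
q = quantize_row(y, scale, fp8_max=448.0)          # q is roughly [224, -448, 336, -448]
```

Because gamma is applied before the scale is chosen, a channel with a small gamma, such as the last two here, does not waste FP8 range.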
Parameters:
- in_dtype (DType): Input data type (float32, float16, or bfloat16).
- out_dtype (DType): Output FP8 data type (float8_e4m3fn or float8_e4m3fnuz).
- scales_dtype (DType): Data type for scale factors (bfloat16, float16, or float32).
- rank (Int): Tensor rank.
- input_fn (fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[in_dtype, width]): Function to load input values.
- target (StringSlice): Target device ("gpu" or "cpu").
- use_dynamic_scaling (Bool): If True, compute scale dynamically; if False, use static_scale.
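The two supported `out_dtype` formats have different finite ranges. This changes the maximum that a scale maps values onto: float8_e4m3fn tops out at 448 and float8_e4m3fnuz at 240. The Python sketch below derives both limits from the bit layouts. It also rounds a value onto the e4m3fn grid using round-half-to-even. The helper names are hypothetical. Saturating out-of-range values to the maximum, rather than producing NaN, is an assumption about the kernel's behavior.

```python
import math

def fp8_max_finite(exp_bits, man_bits, bias, top_mantissa_reserved):
    """Largest finite value of a sign/exponent/mantissa FP8 format without infinities."""
    max_exp_field = (1 << exp_bits) - 1      # all-ones exponent still encodes finite values
    max_man_field = (1 << man_bits) - 1
    if top_mantissa_reserved:
        max_man_field -= 1                   # e4m3fn: S.1111.111 is NaN
    return 2.0 ** (max_exp_field - bias) * (1 + max_man_field / (1 << man_bits))

E4M3FN_MAX = fp8_max_finite(4, 3, bias=7, top_mantissa_reserved=True)     # 448.0
E4M3FNUZ_MAX = fp8_max_finite(4, 3, bias=8, top_mantissa_reserved=False)  # 240.0

def round_to_e4m3fn(x):
    """Hypothetical helper: round-half-to-even onto the e4m3fn grid, saturating at +/-448."""
    if x == 0.0 or math.isnan(x):
        return x
    a = abs(x)
    _, e = math.frexp(a)             # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)             # unbiased exponent; -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)          # 3 mantissa bits -> 8 representable steps per binade
    r = round(a / step) * step       # Python's round() rounds half to even
    return math.copysign(min(r, E4M3FN_MAX), x)
```

For example, 0.3 rounds to 0.3125, and -500.0 saturates to -448.0.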
Args:
- shape (IndexList): Input tensor shape.
- output (NDBuffer): Output buffer that receives the FP8-quantized values.
- gamma (TileTensor): RMSNorm scale parameter (rank 1).
- epsilon (Scalar): Small constant for numerical stability.
- weight_offset (Scalar): Offset to add after normalization.
- ctx (DeviceContextPtr): Device context.
- scale_ub (Float32): Upper bound that limits the dynamic scale factor.
- static_scale (OptionalReg): Static FP8 scale factor (required if use_dynamic_scaling=False).
- scale_output (OptionalReg): Buffer that receives the dynamically computed scales (required if use_dynamic_scaling=True).
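The interaction of `use_dynamic_scaling`, `scale_ub`, `static_scale`, and `scale_output` can be sketched in Python. `select_row_scale` and `MIN_SCALE` are hypothetical and do not come from the kernel's code. The sketch makes three assumptions:

- There is one scale per row, which is consistent with `scale_output` being rank 1.
- `scale_ub` caps the row's absolute maximum before that maximum is mapped onto the FP8 maximum.
- A small floor on the scale prevents division by zero.

```python
MIN_SCALE = 1e-12  # hypothetical floor so an all-zero row does not produce scale 0

def select_row_scale(row_index, row_values, use_dynamic_scaling, scale_ub,
                     static_scale=None, scale_output=None, fp8_max=448.0):
    """Hypothetical helper: pick the FP8 scale for one normalized row."""
    if use_dynamic_scaling:
        if scale_output is None:
            raise ValueError("scale_output is required when use_dynamic_scaling=True")
        # Assumption: scale_ub caps the row's absolute maximum.
        amax = min(max(abs(v) for v in row_values), scale_ub)
        scale = max(amax / fp8_max, MIN_SCALE)
        scale_output[row_index] = scale  # one scale per row is written out
        return scale
    if static_scale is None:
        raise ValueError("static_scale is required when use_dynamic_scaling=False")
    return static_scale

scales = [0.0, 0.0]
s0 = select_row_scale(0, [0.5, -896.0], True, scale_ub=448.0, scale_output=scales)
s1 = select_row_scale(1, [1.75, -3.5], True, scale_ub=448.0, scale_output=scales)
```

Here, `scale_ub = 448` and a float8_e4m3fn target give the first row a scale of 1.0. Its outlier of -896 therefore saturates at -448 instead of stretching the scale for the whole row.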