
Mojo function

quantize_fp8_kernel_per_tensor

quantize_fp8_kernel_per_tensor[out_type: DType, scales_type: DType, in_type: DType, input_fn: def[width: Int, alignment: Int](row: Int, col: Int) capturing -> SIMD[in_type, width], num_threads: Int, group_size: Int, simd_width: Int, num_groups: Int, output_layout: TensorLayout, scales_layout: TensorLayout](output: TileTensor[out_type, output_layout, MutAnyOrigin], scales: TileTensor[scales_type, scales_layout, MutAnyOrigin], scale_ub: Scalar[scales_type], num_rows: Int)

Per-tensor FP8 quantize kernel.

Reads every per-group scale written by quantize_fp8_kernel (stored as scales[group_idx, row]), reduces them to the tensor-wide maximum scale, and re-quantizes every input element with that single scale.

Thread 0 of block (0, 0) overwrites scales[0, 0] with the final per-tensor scale factor so that the caller can read it back.
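The two-pass flow described above can be illustrated with a small host-side reference in plain Python. This is only a sketch under stated assumptions, not the kernel's implementation: the helper names (`per_group_scales`, `quantize_per_tensor`), the FP8 E4M3 maximum of 448, the `min(max_abs, scale_ub) / FP8_MAX` scale formula, and the zero-scale guard are all assumptions made for illustration. Rounding to FP8 is not emulated; values are only scaled and clamped to the representable range.

```python
FP8_E4M3_MAX = 448.0  # assumption: output type is float8_e4m3fn


def per_group_scales(x, group_size, scale_ub):
    """Hypothetical stand-in for the first pass (quantize_fp8_kernel).

    Returns scales laid out as scales[group_idx][row], matching the
    layout described in the text. The scale formula is an assumption.
    """
    num_rows = len(x)
    num_groups = len(x[0]) // group_size
    scales = [[0.0] * num_rows for _ in range(num_groups)]
    for row in range(num_rows):
        for g in range(num_groups):
            group = x[row][g * group_size:(g + 1) * group_size]
            max_abs = max(abs(v) for v in group)
            scales[g][row] = min(max_abs, scale_ub) / FP8_E4M3_MAX
    return scales


def quantize_per_tensor(x, scales):
    """Reference for the per-tensor pass.

    1. Reduce all per-group scales to the tensor-wide maximum.
    2. Re-quantize every element with that single scale.
    3. Overwrite scales[0][0] with the final scale for the caller.
    """
    tensor_scale = max(max(row_scales) for row_scales in scales)
    # Assumption: guard against an all-zero tensor to avoid dividing by zero.
    inv_scale = 1.0 / tensor_scale if tensor_scale > 0.0 else 0.0
    out = [
        [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v * inv_scale)) for v in row]
        for row in x
    ]
    scales[0][0] = tensor_scale
    return out, tensor_scale


# Two rows, four columns, two groups of two elements per row.
x = [[1.0, -2.0, 3.0, 4.0],
     [0.5, 8.0, -16.0, 2.0]]
scales = per_group_scales(x, group_size=2, scale_ub=1200.0)
out, tensor_scale = quantize_per_tensor(x, scales)
```

In this example the largest per-group scale comes from the group containing -16.0, so every element is divided by roughly 16 / 448 and the largest magnitude maps to the edge of the assumed FP8 range. After the call, `scales[0][0]` holds the per-tensor scale, which mirrors how the caller reads it back from the kernel's `scales` tensor.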
