Mojo function
quantize_fp8_kernel_per_tensor
quantize_fp8_kernel_per_tensor[
    out_type: DType,
    scales_type: DType,
    in_type: DType,
    input_fn: def[width: Int, alignment: Int](row: Int, col: Int) capturing -> SIMD[in_type, width],
    num_threads: Int,
    group_size: Int,
    simd_width: Int,
    num_groups: Int,
    output_layout: TensorLayout,
    scales_layout: TensorLayout,
](
    output: TileTensor[out_type, output_layout, MutAnyOrigin],
    scales: TileTensor[scales_type, scales_layout, MutAnyOrigin],
    scale_ub: Scalar[scales_type],
    num_rows: Int,
)
Per-tensor FP8 quantize kernel.
Reads all per-group scales written by quantize_fp8_kernel (stored
as scales[group_idx, row]), finds the tensor-wide maximum scale,
and re-quantizes every element with that single scale.
Thread 0 of block (0, 0) overwrites scales[0, 0] with the final
per-tensor scale factor so the caller can read it back.
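The reduction described above can be sketched in host-side NumPy. This is a minimal illustration of the algorithm, not the kernel itself: the function name, the FP8 E4M3 maximum of 448.0, and the assumption that `scale_ub` acts as an upper bound on the final scale are all stated here for illustration only.

```python
import numpy as np

# Largest finite value representable in FP8 E4M3 (common FP8 format).
FP8_E4M3_MAX = 448.0

def quantize_fp8_per_tensor(x, group_scales, scale_ub=None):
    """Sketch of the per-tensor pass: reduce the per-group scales
    from a first-pass kernel to one tensor-wide scale, then
    re-quantize every element with that single scale."""
    # Tensor-wide scale is the maximum over all per-group scales.
    scale = float(np.max(group_scales))
    # Assumed semantics: scale_ub caps the final scale from above.
    if scale_ub is not None:
        scale = min(scale, float(scale_ub))
    # Re-quantize: divide by the scale, clamp to the FP8 range.
    # (A real kernel would then cast to the FP8 output dtype.)
    q = np.clip(np.asarray(x, dtype=np.float64) / scale,
                -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scale

# Example: two groups whose first-pass scales are 0.5 and 1.0;
# the per-tensor scale is their maximum, 1.0.
q, s = quantize_fp8_per_tensor([100.0, -200.0, 448.0],
                               group_scales=[0.5, 1.0])
```

In the actual kernel the reduction runs on device and thread 0 of block (0, 0) writes the resulting scale back into `scales[0, 0]`; the sketch returns it to the caller instead.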