Mojo function
quantize_tensor_dynamic_scaled_fp8
quantize_tensor_dynamic_scaled_fp8[out_dtype: DType, in_dtype: DType, num_threads: Int = 256, pdl_level: PDLLevel = PDLLevel(1)](out_tensor: TileTensor[out_dtype, out_tensor.LayoutType, out_tensor.origin, address_space=out_tensor.address_space, linear_idx_type=out_tensor.linear_idx_type, element_size=out_tensor.element_size], in_tensor: TileTensor[in_dtype, in_tensor.LayoutType, in_tensor.origin, address_space=in_tensor.address_space, linear_idx_type=in_tensor.linear_idx_type, element_size=in_tensor.element_size], scale_global: TileTensor[DType.float32, scale_global.LayoutType, scale_global.origin, address_space=scale_global.address_space, linear_idx_type=scale_global.linear_idx_type, element_size=scale_global.element_size], ctx: DeviceContext)
Per-tensor dynamic FP8 quantization.
First computes the scale max |x| / max_finite[out_dtype] by reducing over the entire
tensor (max_reduction_scale_kernel), then runs the same elementwise path as
quantize_static_scaled_fp8, with the scale read from scale_global on
device (no D2H round trip).
Args:
- out_tensor (TileTensor): FP8 output, same shape as in_tensor.
- in_tensor (TileTensor): BF16/FP16/FP32 input.
- scale_global (TileTensor): Length-1 FP32 TileTensor holding the computed scale.
- ctx (DeviceContext): Device context.
Raises:
If buffer sizes or tensor ranks are inconsistent with the above.
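The two-pass algorithm above (reduce for the scale, then quantize elementwise) can be sketched numerically in plain Python. This is an illustrative sketch of the math only, not the Mojo kernel: FP8_E4M3_MAX (448.0, the largest finite float8_e4m3 value) stands in for max_finite[out_dtype], and the final cast to the FP8 dtype is omitted, so values stay as Python floats.

```python
# Hypothetical illustration of per-tensor dynamic FP8 quantization.
# Assumes out_dtype is float8_e4m3, whose largest finite value is 448.0;
# the real kernel uses max_finite[out_dtype] for whatever dtype is chosen.
FP8_E4M3_MAX = 448.0

def quantize_dynamic_fp8(x):
    # Pass 1 (reduction): scale = max |x| / max_finite[out_dtype].
    scale = max(abs(v) for v in x) / FP8_E4M3_MAX
    # Pass 2 (elementwise): divide by the scale and clamp to the finite
    # FP8 range, as in the static-scale path. The real kernel then casts
    # each value to the FP8 dtype; here they remain floats.
    q = [min(max(v / scale, -FP8_E4M3_MAX), FP8_E4M3_MAX) for v in x]
    return q, scale
```

Because the scale is derived from the tensor's own maximum, the largest-magnitude element always maps exactly to ±max_finite, which is what makes the quantization "dynamic" per tensor.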