Mojo function

quantize_tensor_dynamic_scaled_fp8

quantize_tensor_dynamic_scaled_fp8[out_dtype: DType, in_dtype: DType, num_threads: Int = 256, pdl_level: PDLLevel = PDLLevel(1)](out_tensor: TileTensor[out_dtype, out_tensor.LayoutType, out_tensor.origin, address_space=out_tensor.address_space, linear_idx_type=out_tensor.linear_idx_type, element_size=out_tensor.element_size], in_tensor: TileTensor[in_dtype, in_tensor.LayoutType, in_tensor.origin, address_space=in_tensor.address_space, linear_idx_type=in_tensor.linear_idx_type, element_size=in_tensor.element_size], scale_global: TileTensor[DType.float32, scale_global.LayoutType, scale_global.origin, address_space=scale_global.address_space, linear_idx_type=scale_global.linear_idx_type, element_size=scale_global.element_size], ctx: DeviceContext)

Per-tensor dynamic FP8 quantization.

First reduces max |x| / max_finite[out_dtype] over the entire tensor (max_reduction_scale_kernel), then runs the same elementwise path as quantize_static_scaled_fp8, with the scale read from scale_global on the device (no device-to-host round trip).
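The scaling arithmetic described above can be sketched in plain Python (this is an illustration of the math only, not the Mojo kernel; the FP8 e4m3 max-finite value of 448.0 is an assumption about out_dtype):

```python
# Sketch of per-tensor dynamic FP8 scaling math (illustrative, not the kernel).
# Assumes out_dtype is float8_e4m3, whose largest finite value is 448.0.
FP8_E4M3_MAX = 448.0

def compute_dynamic_scale(x):
    """Reduction step: scale = max(|x|) / max_finite[out_dtype] over the whole tensor."""
    amax = max(abs(v) for v in x)
    return amax / FP8_E4M3_MAX

def quantize(x, scale):
    """Elementwise step: divide by the scale and clamp into the FP8-representable
    range. (The real kernel also casts/rounds to the nearest FP8 value; omitted here.)"""
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in x]

x = [0.5, -896.0, 224.0]
scale = compute_dynamic_scale(x)   # 896.0 / 448.0 = 2.0
q = quantize(x, scale)             # [0.25, -448.0, 112.0]
```

With this scale, the largest-magnitude input maps exactly onto the FP8 range boundary, so no value saturates beyond representable limits.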

Args:

  • out_tensor (TileTensor): FP8 output, same shape as in_tensor.
  • in_tensor (TileTensor): BF16/FP16/FP32 input.
  • scale_global (TileTensor): Length-1 FP32 TileTensor holding the computed scale.
  • ctx (DeviceContext): Device context.

Raises:

If buffer sizes or tensor ranks are inconsistent with the requirements above.