
Mojo function

max_reduction_scale_kernel

max_reduction_scale_kernel[in_dtype: DType, out_dtype: DType, input_layout: TensorLayout, scale_layout: TensorLayout, num_threads: Int](scale_global: TileTensor[DType.float32, scale_layout, MutAnyOrigin], input_tensor: TileTensor[in_dtype, input_layout, MutAnyOrigin])

Per-row strided max-|x| reduction into a global FP8 scale.

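Each block processes one row. Its threads stride across the hidden dimension, each tracking the largest absolute value it sees, and the block then reduces these per-thread results to a row-wise maximum absolute value (row_max). Finally, thread 0 atomically updates scale_global with row_max / max_finite[out_dtype].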
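The steps above can be modelled on the CPU for intuition. The following is a minimal sketch in plain Python, not the Mojo kernel itself. It assumes that the atomic update is a maximum (the source only says "atomically updates") and that out_dtype is float8_e4m3fn, whose largest finite value is 448. The function name max_reduction_scale_reference and the constant FP8_E4M3FN_MAX are hypothetical names introduced for illustration.

```python
# Reference model only: plain Python, not the Mojo GPU kernel.
# Assumption: out_dtype is float8_e4m3fn, whose largest finite value is 448.
FP8_E4M3FN_MAX = 448.0


def max_reduction_scale_reference(rows, num_threads, max_finite=FP8_E4M3FN_MAX):
    """Model one block per row with a strided per-thread max-|x| scan."""
    scale_global = 0.0  # the docs require the scale to be zero before launch
    for row in rows:  # each row stands in for one block
        # Thread t visits columns t, t + num_threads, t + 2 * num_threads, ...
        thread_max = [0.0] * num_threads
        for t in range(num_threads):
            for col in range(t, len(row), num_threads):
                thread_max[t] = max(thread_max[t], abs(row[col]))
        # Block-level reduction of the per-thread results.
        row_max = max(thread_max)
        # Thread 0's atomic update, modelled as a max (assumption).
        scale_global = max(scale_global, row_max / max_finite)
    return scale_global


example_rows = [[1.0, -2.0, 3.0], [0.5, -896.0, 4.0]]
example_scale = max_reduction_scale_reference(example_rows, num_threads=2)
```

Under these assumptions the result does not depend on num_threads, since the strided scan only changes which thread sees which column, not the set of values reduced.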

Args:

  • scale_global (TileTensor): Length-1 FP32 TileTensor that receives the scale; must be zero-initialized before launch.
  • input_tensor (TileTensor): Rank-2 input of element type in_dtype; each row is reduced by one block.
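The zero-initialization requirement on scale_global can be illustrated with a short sketch. It assumes, as a guess not confirmed by the source, that the atomic update keeps the maximum of the old and new values; under that assumption a stale non-zero value larger than every row's contribution would survive the launch. The helper run_launch and the sample values are hypothetical.

```python
# Plain-Python illustration, not the Mojo kernel.
# Assumption: the atomic update on scale_global is a max.
def run_launch(initial_scale, row_scales):
    """Apply each row's candidate scale to the shared value via max."""
    scale_global = initial_scale
    for candidate in row_scales:
        scale_global = max(scale_global, candidate)
    return scale_global


row_scales = [0.25, 0.5, 0.125]  # hypothetical per-row row_max / max_finite

clean_result = run_launch(0.0, row_scales)  # buffer zeroed before launch
stale_result = run_launch(4.0, row_scales)  # leftover from an earlier launch
```

With a zeroed buffer the result is the largest per-row value; with the stale buffer the old value wins and the new input is effectively ignored.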
