Mojo function
max_reduction_scale_kernel
max_reduction_scale_kernel[in_dtype: DType, out_dtype: DType, input_layout: TensorLayout, scale_layout: TensorLayout, num_threads: Int](scale_global: TileTensor[DType.float32, scale_layout, MutAnyOrigin], input_tensor: TileTensor[in_dtype, input_layout, MutAnyOrigin])
Per-row strided max-|x| reduction into a global FP8 scale.
One block scans one row: threads stride across the hidden dimension, reduce
to a row-wise max absolute value, then thread 0 atomically updates
scale_global with row_max / max_finite[out_dtype].
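The reduction described above can be mirrored on the CPU to check expected scale values. The following is a plain-Python sketch, not the kernel itself, and it rests on several assumptions. It assumes the atomic update is an atomic max, because the page says only "atomically updates". It assumes `out_dtype` is `DType.float8_e4m3fn`, whose largest finite value is 448.0. The helper names `block_row_max_abs` and `max_reduction_scale` are hypothetical.

```python
# Hypothetical CPU reference for the math of max_reduction_scale_kernel.
# Assumptions (not confirmed by the docs): the atomic update is an atomic max,
# and out_dtype is float8_e4m3fn, whose largest finite value is 448.0.
FP8_E4M3FN_MAX_FINITE = 448.0


def block_row_max_abs(row, num_threads):
    """One 'block' reduces one row; each thread strides over the hidden dim."""
    partial = [0.0] * num_threads
    for tid in range(num_threads):
        # Thread `tid` visits columns tid, tid + num_threads, tid + 2*num_threads, ...
        for col in range(tid, len(row), num_threads):
            partial[tid] = max(partial[tid], abs(row[col]))
    # Block-wide reduction of the per-thread partial maxima.
    return max(partial)


def max_reduction_scale(rows, num_threads, max_finite=FP8_E4M3FN_MAX_FINITE, scale=0.0):
    """Emulate one launch; `scale` plays the role of scale_global[0]."""
    for row in rows:  # one block per row
        row_max = block_row_max_abs(row, num_threads)
        # Thread 0 of the block: atomic max into the global scale (assumed).
        scale = max(scale, row_max / max_finite)
    return scale


x = [[1.0, -896.0, 3.0, 0.5],
     [-2.0, 4.0, 224.0, -8.0]]
print(max_reduction_scale(x, num_threads=2))  # 896 / 448 = 2.0
```

Each simulated thread keeps its own partial maximum over a disjoint, strided slice of the row. Only the per-row result is combined across blocks, so one atomic operation per block (issued by thread 0) is enough.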
Args:
- scale_global (TileTensor): Length-1 FP32 TileTensor; must be zero before launch.
- input_tensor (TileTensor): Rank-2 input; each row is reduced along the hidden dimension.
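The zero-before-launch requirement on `scale_global` is consistent with an atomic-max update. Every candidate value `row_max / max_finite` is non-negative, so 0.0 is the identity for that max. A buffer left over from a previous launch would keep its stale, possibly larger value. The following plain-Python sketch illustrates this. It assumes an atomic max and an `out_dtype` of `float8_e4m3fn` (max finite 448.0), and `launch` is a hypothetical stand-in for one kernel launch.

```python
# Hypothetical illustration of why scale_global must be zeroed before launch.
# Assumes the kernel's atomic update is an atomic max (not confirmed by the docs)
# and that out_dtype is float8_e4m3fn (largest finite value 448.0).
FP8_E4M3FN_MAX_FINITE = 448.0


def launch(rows, scale_global):
    """Emulate one launch: fold each row's max|x| / 448 into scale_global[0]."""
    for row in rows:
        row_max = max(abs(v) for v in row)
        scale_global[0] = max(scale_global[0], row_max / FP8_E4M3FN_MAX_FINITE)


scale_global = [0.0]                  # length-1 FP32 buffer, zeroed
launch([[448.0, -1.0]], scale_global)
print(scale_global[0])                # 1.0

launch([[224.0, 2.0]], scale_global)  # no reset: the stale 1.0 survives
print(scale_global[0])                # 1.0, not 0.5

scale_global[0] = 0.0                 # reset before the next launch
launch([[224.0, 2.0]], scale_global)
print(scale_global[0])                # 0.5
```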