
Mojo function

get_accum_type

get_accum_type[type: DType, *, preferred_accum_type: DType = float32]() -> DType

Returns the recommended type for accumulation operations.

Half-precision and float8 types can introduce numerical error when they are used in reduction or accumulation operations. This function returns a higher-precision type to use for accumulation if a half-precision or float8 type is provided; otherwise it returns the original type.
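To make the accumulation problem concrete, here is a small Python sketch (not Mojo, and not this function's implementation) that emulates float16 and float32 arithmetic with the standard struct module's "e" and "f" formats. The helper names are hypothetical. Summing 4096 ones in a float16 accumulator stalls at 2048, because float16 values between 2048 and 4096 are spaced 2 apart and 2048 + 1 rounds back to 2048. Accumulating in float32 and rounding only the final result back to float16 gives the exact answer.

```python
import struct


def to_f16(x: float) -> float:
    """Round x to the nearest IEEE 754 binary16 (float16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round x to the nearest IEEE 754 binary32 (float32) value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def sum_in_f16(values: list) -> float:
    """Sum values, rounding the running total to float16 after every add."""
    acc = 0.0
    for v in values:
        acc = to_f16(acc + to_f16(v))
    return acc


def sum_in_f32_then_f16(values: list) -> float:
    """Accumulate float16 inputs in float32, then round the result to float16."""
    acc = 0.0
    for v in values:
        acc = to_f32(acc + to_f16(v))
    return to_f16(acc)


ones = [1.0] * 4096
print(sum_in_f16(ones))           # 2048.0: every 2048 + 1 rounds back to 2048
print(sum_in_f32_then_f16(ones))  # 4096.0
```

The same pattern (read low-precision inputs, accumulate in a wider type, convert the final result back) is what choosing an accumulation type is meant to support.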

The rules are as follows:

  • If the type is a float8 type, return a float16 type.
  • If the type is a bfloat16 type, return a float32 type.
  • If the type is a float16 type, return a float32 type if preferred_accum_type is float32; otherwise, return a float16 type.
  • Otherwise, return the original type.
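The rules above amount to a small decision function. The following Python sketch mirrors only the documented rules; DType here is a hypothetical enum standing in for Mojo's DType, and its float8 member names are illustrative. It is not the Mojo implementation, which takes type and preferred_accum_type as compile-time parameters.

```python
from enum import Enum


class DType(Enum):
    """Hypothetical stand-in for Mojo's DType; only the cases the rules mention."""
    FLOAT8_E4M3 = "float8_e4m3"  # illustrative float8 variant name
    FLOAT8_E5M2 = "float8_e5m2"  # illustrative float8 variant name
    BFLOAT16 = "bfloat16"
    FLOAT16 = "float16"
    FLOAT32 = "float32"
    FLOAT64 = "float64"
    INT32 = "int32"


FLOAT8_TYPES = {DType.FLOAT8_E4M3, DType.FLOAT8_E5M2}


def accum_type(dtype: DType, preferred_accum_type: DType = DType.FLOAT32) -> DType:
    """Apply the documented accumulation-type rules in order."""
    if dtype in FLOAT8_TYPES:
        return DType.FLOAT16
    if dtype is DType.BFLOAT16:
        return DType.FLOAT32
    if dtype is DType.FLOAT16:
        # Only float16 inputs consult the preferred accumulation type.
        if preferred_accum_type is DType.FLOAT32:
            return DType.FLOAT32
        return DType.FLOAT16
    # Every other type is returned unchanged.
    return dtype


print(accum_type(DType.FLOAT16).value)                                      # float32
print(accum_type(DType.FLOAT16, preferred_accum_type=DType.FLOAT16).value)  # float16
print(accum_type(DType.INT32).value)                                        # int32
```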

Parameters:

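  • type (DType): The element type of the values being accumulated.
  • preferred_accum_type (DType): The preferred type for accumulation. Defaults to float32. Under the rules above, it only affects the result for float16 inputs.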

Returns:

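The accumulation type selected by the rules above: float16 for float8 types, float32 for bfloat16, float32 or float16 for float16 (depending on preferred_accum_type), and type itself otherwise.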