
type_promotion

Defines how types can promote in binary operations.

For instance,

```mojo
var g = Graph(...)
var x = g.scalar[DType.int16](1)
var y = g.scalar[DType.float32](1.)
g.output(x + y)  # what dtype does this have?!
```

We mostly borrow semantics from JAX. We construct our type hierarchy as a lattice, and the type promotion between two types is their join on the lattice.
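To make the lattice idea concrete, here's a small Python model (not the MAX implementation) of promotion as a join. The edge set below is an assumption distilled from the reference table in this document, with `index` and `address` omitted; each edge `u -> v` means "`u` may promote directly to `v`", and the join of two types is their least common upper bound:

```python
# Simplified promotion sub-lattice (assumed from the table below;
# index and address are omitted). Edge u -> v: u may promote to v.
EDGES = {
    "bool": ["int8", "uint8"],
    "int8": ["int16"],
    "int16": ["int32"],
    "int32": ["int64"],
    "int64": ["float16"],
    "uint8": ["uint16", "int16"],
    "uint16": ["uint32", "int32"],
    "uint32": ["uint64", "int64"],
    "uint64": ["float16"],
    "float16": ["bfloat16"],
    "bfloat16": ["float32"],
    "float32": ["tensor_float32"],
    "tensor_float32": ["float64"],
    "float64": [],
}

def upper_bounds(t):
    """All types reachable from t (including t itself)."""
    seen, stack = set(), [t]
    while stack:
        u = stack.pop()
        if u not in seen:
            seen.add(u)
            stack.extend(EDGES[u])
    return seen

def join(a, b):
    """Least upper bound: the common upper bound below all others."""
    common = upper_bounds(a) & upper_bounds(b)
    return next(t for t in common if common <= upper_bounds(t))

print(join("int16", "float32"))  # float32
print(join("int64", "uint64"))   # float16 -- no integer type holds both
```

Note how the mixed signed/unsigned edges (`uint8 -> int16`, and so on) reproduce entries like `int64` joined with `uint64` giving `float16`: there is no integer type that can represent both, so the join lands on the smallest float.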

Here's a full reference table of the promotion semantics:

| Type | bool | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | index | address | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| bool | bool | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | uint64 | uint64 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| int8 | int8 | int8 | int16 | int32 | int64 | int16 | int32 | int64 | float16 | float16 | float16 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| int16 | int16 | int16 | int16 | int32 | int64 | int16 | int32 | int64 | float16 | float16 | float16 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| int32 | int32 | int32 | int32 | int32 | int64 | int32 | int32 | int64 | float16 | float16 | float16 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| int64 | int64 | int64 | int64 | int64 | int64 | int64 | int64 | int64 | float16 | float16 | float16 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| uint8 | uint8 | int16 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | uint64 | uint64 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| uint16 | uint16 | int32 | int32 | int32 | int64 | uint16 | uint16 | uint32 | uint64 | uint64 | uint64 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| uint32 | uint32 | int64 | int64 | int64 | int64 | uint32 | uint32 | uint32 | uint64 | uint64 | uint64 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| uint64 | uint64 | float16 | float16 | float16 | float16 | uint64 | uint64 | uint64 | uint64 | uint64 | uint64 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| index | uint64 | float16 | float16 | float16 | float16 | uint64 | uint64 | uint64 | uint64 | index | uint64 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| address | uint64 | float16 | float16 | float16 | float16 | uint64 | uint64 | uint64 | uint64 | uint64 | address | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| float16 | float16 | float16 | float16 | float16 | float16 | float16 | float16 | float16 | float16 | float16 | float16 | float16 | bfloat16 | float32 | tensor_float32 | float64 |
| bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | bfloat16 | float32 | tensor_float32 | float64 |
| float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | tensor_float32 | float64 |
| tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | tensor_float32 | float64 |
| float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 |

Some less common dtypes here that might not be familiar:

  • DType.index is an unsigned word-size integer, like size_t. It promotes up to uint64 regardless of architecture.
  • DType.address is an unsigned word-size integer used for pointer types. It promotes up to uint64 regardless of architecture.
  • DType.bfloat16 is an alternative float representation. Compared to DType.float16 it moves 3 precision bits into the exponent. It promotes up to a DType.float32.
  • DType.tensor_float32 is an alternative float representation. It is lower precision than DType.float32. It promotes up to a DType.float64.

address and index aren't common and are more a detail of Mojo's DType implementation, so we never promote other types to them, and they instead behave as uint64.

bfloat16 and tensor_float32 are placed above their higher-precision counterparts in the lattice; in other words, float16 can promote to bfloat16, but not vice versa. This means that unless you're explicitly using bfloat16 or tensor_float32, you'll never accidentally promote to them; you'll always promote to the higher-precision alternative instead.

If you are using them, be careful to manage your type promotions correctly. Adding a bfloat16 to a float16 gives a bfloat16, for instance. This shouldn't matter in most cases, but you may lose precision in an operation like b16 ** f16, where f16 gets cast to a bfloat16 before the operation and the difference in precision is actually significant to the result, producing a lower-precision output. You can work around this with explicit casts, e.g. b16.cast(DType.float16) ** f16.
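To see why the float16 → bfloat16 cast can lose precision, here is a standalone Python sketch (stdlib only, not MAX) that rounds a float to the nearest bfloat16 value by keeping the top 16 bits of its float32 encoding. The value 1 + 2⁻⁹ is exactly representable in float16 (10 stored mantissa bits) but not in bfloat16 (7 stored mantissa bits), so it collapses to 1.0:

```python
import struct

def to_bfloat16(x):
    """Round a float to the nearest bfloat16 value by truncating the
    float32 encoding to its top 16 bits, round-to-nearest-even.
    Simplified sketch: no NaN/overflow handling."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    lower = bits & 0xFFFF
    bits >>= 16
    if lower > 0x8000 or (lower == 0x8000 and bits & 1):
        bits += 1  # round up (ties to even)
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

# 2**-9 fits in float16's 10 mantissa bits, but not bfloat16's 7:
print(to_bfloat16(1.0 + 2**-9))  # 1.0 -- the low bit is rounded away
# 2**-7 does fit in bfloat16, so it round-trips:
print(to_bfloat16(1.0 + 2**-7))  # 1.0078125
```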

implicit_cast_type

```mojo
implicit_cast_type(lhs: DType, rhs: DType) -> DType
```

Returns the smallest DType to which both DTypes may be promoted.

Raises: If either DType isn't supported by promotion.

Args:

  • lhs (DType): The first dtype.
  • rhs (DType): The second dtype.

Returns:

The promotion type for a binary operation on two values of types lhs and rhs respectively.

implicit_cast

```mojo
implicit_cast(lhs: Symbol, rhs: Symbol) -> SymbolTuple
```

Performs implicit type conversion between operands.

See the max.graph.type_promotion documentation for details on the promotion rules.

Args:

  • lhs (Symbol): The left-hand-side argument.
  • rhs (Symbol): The right-hand-side argument.

Returns:

A SymbolTuple containing lhs and rhs, cast to the promoted dtype if necessary.
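The relationship between the two functions can be sketched in plain Python (a hypothetical analogue; the real API operates on graph Symbols, and `PROMOTION` below hand-copies just a few entries from the reference table):

```python
# A few sample entries from the promotion table (assumed subset).
PROMOTION = {
    ("int16", "float32"): "float32",
    ("int64", "uint64"): "float16",
    ("bfloat16", "float16"): "bfloat16",
}

def implicit_cast_type(lhs, rhs):
    # The table is symmetric, so look the pair up in either order.
    key = (lhs, rhs) if (lhs, rhs) in PROMOTION else (rhs, lhs)
    return PROMOTION.get(key, lhs if lhs == rhs else None)

def _maybe_cast(dtype, target):
    # A side that already has the target dtype is left untouched.
    return dtype if dtype == target else f"cast({dtype} -> {target})"

def implicit_cast(lhs, rhs):
    target = implicit_cast_type(lhs, rhs)
    return _maybe_cast(lhs, target), _maybe_cast(rhs, target)

print(implicit_cast("int16", "float32"))
# ('cast(int16 -> float32)', 'float32')
```

The key property this models: only the operand whose dtype differs from the promoted type gets a cast inserted; the other passes through unchanged.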