Mojo module

elementwise

Ops that perform element-wise computations/comparisons.

Operations in this module are either unary or binary.

Unary Operations

Elementwise unary operations all have the following properties:

  • They operate on a single symbolic tensor value of any shape
  • Their output is a single symbolic tensor value with the same shape as their input
  • The computation they represent is itemwise-independent: the value at any position of the output tensor depends only on the input value at that same position, and on no others.
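A minimal sketch of these properties, assuming nested Python lists as a stand-in for a tensor's runtime value (this is not the Mojo Graph API): a unary elementwise op amounts to applying one scalar function at every position independently, so the output always has the input's shape. `unary_elementwise` is a hypothetical helper used only for illustration.

```python
from typing import Any, Callable


def unary_elementwise(f: Callable[[Any], Any], x: Any) -> Any:
    """Apply `f` to every scalar in a nested-list "tensor", preserving its shape."""
    if isinstance(x, list):
        # Recurse into each sub-tensor; the nesting (shape) is rebuilt unchanged.
        return [unary_elementwise(f, item) for item in x]
    # Scalar position: the output depends only on the input value at this position.
    return f(x)


result = unary_elementwise(abs, [[-1.0, 2.0], [3.0, -4.0]])
print(result)  # [[1.0, 2.0], [3.0, 4.0]]
```

Because each position is computed in isolation, an implementation is free to evaluate positions in any order or in parallel.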

Binary Operations

Elementwise binary operations all have the following properties:

  • They operate on two symbolic tensor values, a left value and a right value.
  • The input tensor types must be compatible according to the broadcasting rules.
  • If the input tensor types have different element types, they will each be promoted to the same dtype according to the dtype promotion rules before executing the operation. This may involve a cast that changes the representation (including precision) of the data values.
  • Their output is a single symbolic tensor value with
    • a dtype that depends on the op and on the promoted dtype, i.e. promote(lhs, rhs);
    • a shape equal to broadcast(lhs, rhs).
  • The computation they represent is itemwise-independent: after both inputs are broadcast to the same shape, the value at any position of the output tensor depends only on the values of the two broadcast inputs at that same position, and on no others.
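The itemwise-independence property for binary ops can be sketched in plain Python, again with nested lists standing in for runtime tensor values (not the Mojo Graph API). `binary_elementwise` is a hypothetical helper; it assumes broadcasting and dtype promotion have already been applied, so both inputs arrive with identical shapes.

```python
from typing import Any, Callable


def binary_elementwise(f: Callable[[Any, Any], Any], lhs: Any, rhs: Any) -> Any:
    """Combine two nested-list "tensors" of identical shape position by position.

    Assumes broadcasting (and dtype promotion) has already been applied.
    """
    lhs_is_list, rhs_is_list = isinstance(lhs, list), isinstance(rhs, list)
    if lhs_is_list != rhs_is_list:
        raise ValueError("inputs must have the same rank after broadcasting")
    if lhs_is_list:
        if len(lhs) != len(rhs):
            raise ValueError("inputs must have the same shape after broadcasting")
        return [binary_elementwise(f, a, b) for a, b in zip(lhs, rhs)]
    # Scalar position: uses only the two input values at this same position.
    return f(lhs, rhs)


product = binary_elementwise(lambda a, b: a * b, [[1, 2], [3, 4]], [[10, 20], [30, 40]])
print(product)  # [[10, 40], [90, 160]]
```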

DType Promotion Rules

The Graph API splits dtype promotion into two pieces: bit width and category. Bit width is simply the number of bits needed to represent a dtype. Category is an ordered hierarchy: bool < unsigned int < signed int < float.

A promotion candidate is calculated between two dtypes (a and b) as: (max(category(a), category(b)), max(bitwidth(a), bitwidth(b))).

An exception will be raised if either input dtype might contain a value that is unrepresentable in the promotion candidate (e.g. u32 -> i32 or i32 -> f32).

An exception will be raised if an input has the same bit width as the promotion candidate but a different format (e.g. f16 -> bf16 or f32 -> tf32).

If no exception is raised, the promotion candidate is accepted. All inputs will be cast to the promotion candidate before the underlying operation is run.
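The promotion steps above can be sketched in Python as follows. This is an illustration under stated assumptions, not the Graph API's implementation: the dtype table, the significand widths used for the representability test, and the rule for choosing a float candidate's format (reuse an input float of the candidate width, else the IEEE format of that width) are all assumptions, and `promote`, `representable`, and `PromotionError` are hypothetical names. `tf32` is omitted.

```python
BOOL, UINT, SINT, FLOAT = 0, 1, 2, 3  # category order: bool < unsigned < signed < float

# Hypothetical table: name -> (category, bit width, significand bits incl. implicit bit).
DTYPES = {
    "bool": (BOOL, 1, 0),
    "u8": (UINT, 8, 0), "u16": (UINT, 16, 0), "u32": (UINT, 32, 0), "u64": (UINT, 64, 0),
    "i8": (SINT, 8, 0), "i16": (SINT, 16, 0), "i32": (SINT, 32, 0), "i64": (SINT, 64, 0),
    "f16": (FLOAT, 16, 11), "bf16": (FLOAT, 16, 8),
    "f32": (FLOAT, 32, 24), "f64": (FLOAT, 64, 53),
}
INT_NAMES = {(UINT, w): f"u{w}" for w in (8, 16, 32, 64)}
INT_NAMES.update({(SINT, w): f"i{w}" for w in (8, 16, 32, 64)})
DEFAULT_FLOAT = {16: "f16", 32: "f32", 64: "f64"}  # assumed IEEE default per width


class PromotionError(Exception):
    """Raised when the promotion candidate cannot safely hold an input dtype."""


def representable(src: str, dst: str) -> bool:
    """True if every value of `src` is exactly representable in `dst` (for this table)."""
    s_cat, s_bits, s_sig = DTYPES[src]
    d_cat, d_bits, d_sig = DTYPES[dst]
    if src == dst or s_cat == BOOL:
        return True
    if s_cat > d_cat:
        return False  # e.g. float -> int, signed -> unsigned
    if d_cat == FLOAT:
        if s_cat == FLOAT:
            return d_bits > s_bits and d_sig >= s_sig  # strictly wider float
        # Integers need enough significand bits for their largest magnitude.
        magnitude_bits = s_bits if s_cat == UINT else s_bits - 1
        return d_sig >= magnitude_bits
    if s_cat == UINT and d_cat == SINT:
        return d_bits > s_bits  # signed needs one extra bit for the sign
    return d_bits >= s_bits  # same integer category


def promote(a: str, b: str) -> str:
    """Return the promoted dtype of `a` and `b`, or raise PromotionError."""
    # Promotion candidate: (max(category), max(bitwidth)).
    cat = max(DTYPES[a][0], DTYPES[b][0])
    width = max(DTYPES[a][1], DTYPES[b][1])
    if cat == BOOL:
        result = "bool"
    elif cat != FLOAT:
        result = INT_NAMES[(cat, width)]
    else:
        # Assumption: reuse the format of an input float of the candidate width.
        same_width = [d for d in (a, b) if DTYPES[d][:2] == (FLOAT, width)]
        result = same_width[0] if same_width else DEFAULT_FLOAT[width]
    for src in (a, b):
        if DTYPES[src][:2] == (FLOAT, width) and src != result:
            raise PromotionError(f"{src} -> {result}: same bit width, different format")
        if not representable(src, result):
            raise PromotionError(f"{src} -> {result}: some values are unrepresentable")
    return result


# Expected: i16, f16, then three errors (u32/i32, i32/f32, f16/bf16).
for a, b in [("u8", "i16"), ("i8", "f16"), ("u32", "i32"), ("i32", "f32"), ("f16", "bf16")]:
    try:
        print(a, b, "->", promote(a, b))
    except PromotionError as err:
        print(a, b, "-> error:", err)
```

The sketch checks each input separately against the candidate, which mirrors the rule that every input is cast to the candidate before the operation runs.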

Broadcasting Rules

Given two input tensor shapes, broadcasting works as follows:

  1. Prepend static dimensions of size 1 to the lower-rank tensor until both tensors have the same rank.
  2. If a dimension is a static dimension of size 1, it broadcasts to the size of the corresponding dimension in the other tensor.
  3. All other dimensions will be asserted to be equivalent. If they are not, an exception will be raised.
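The three broadcasting steps can be sketched in Python for the case where every dimension is a static integer; symbolic or dynamic dimensions, which the Graph API asserts to be equivalent at step 3, are not modelled. `broadcast_shape` is a hypothetical helper, not a Graph API function. For static shapes these steps coincide with NumPy's broadcasting rule.

```python
from typing import Tuple


def broadcast_shape(lhs: Tuple[int, ...], rhs: Tuple[int, ...]) -> Tuple[int, ...]:
    """Result shape of broadcasting two static shapes."""
    rank = max(len(lhs), len(rhs))
    # Step 1: prepend size-1 dimensions to the lower-rank shape.
    lhs = (1,) * (rank - len(lhs)) + tuple(lhs)
    rhs = (1,) * (rank - len(rhs)) + tuple(rhs)
    out = []
    for a, b in zip(lhs, rhs):
        if a == 1:  # Step 2: a size-1 dimension takes the other tensor's size
            out.append(b)
        elif b == 1:
            out.append(a)
        elif a == b:  # Step 3: all other dimensions must match
            out.append(a)
        else:
            raise ValueError(f"cannot broadcast {lhs} with {rhs}: {a} != {b}")
    return tuple(out)


print(broadcast_shape((3, 1), (4,)))  # (3, 4)
print(broadcast_shape((2, 3), (3,)))  # (2, 3)
```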

Functions

  • abs: Computes the elementwise absolute value of a symbolic tensor.
  • add: Adds two symbolic tensors.
  • cos: Computes the elementwise cosine of a symbolic tensor.
  • div: Divides two symbolic tensors.
  • equal: Computes the elementwise equality comparison between two symbolic tensors.
  • erf: Computes the elementwise error function of a symbolic tensor.
  • exp: Computes the elementwise exp function of a symbolic tensor.
  • floor: Computes the elementwise floor of a symbolic tensor.
  • gelu: Computes the elementwise gelu function of a symbolic tensor.
  • greater: Computes the elementwise greater than comparison between two symbolic tensors.
  • greater_equal: Computes the elementwise greater-or-equal comparison between two symbolic tensors.
  • is_inf: Computes, elementwise, whether each value of a symbolic tensor is infinite.
  • is_nan: Computes, elementwise, whether each value of a symbolic tensor is NaN.
  • log: Computes the elementwise natural logarithm of a symbolic tensor.
  • log1p: Computes the elementwise logarithm of 1 plus a symbolic tensor.
  • logsoftmax: Computes the elementwise logsoftmax of a symbolic tensor.
  • max: Computes the elementwise maximum of two symbolic tensors.
  • min: Computes the elementwise minimum of two symbolic tensors.
  • mod: Computes the elementwise modulus of two symbolic tensors.
  • mul: Computes the elementwise multiplication of two symbolic tensors.
  • not_equal: Computes the elementwise inequality comparison between two symbolic tensors.
  • pow: Computes the elementwise exponentiation of two symbolic tensors.
  • relu: Computes the elementwise relu of a symbolic tensor.
  • round: Computes the elementwise round of a symbolic tensor.
  • roundeven: Computes the elementwise roundeven of a symbolic tensor.
  • rsqrt: Computes the elementwise inverse-square-root of a symbolic tensor.
  • sigmoid: Computes the elementwise sigmoid of a symbolic tensor.
  • silu: Computes the elementwise silu of a symbolic tensor.
  • sin: Computes the elementwise sine of a symbolic tensor.
  • softmax: Computes the elementwise softmax of a symbolic tensor.
  • sqrt: Computes the elementwise square root of a symbolic tensor.
  • sub: Computes the elementwise subtraction of two symbolic tensors.
  • tanh: Computes the elementwise tanh of a symbolic tensor.
  • trunc: Computes the elementwise truncation of a symbolic tensor.
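Three of the rounding-style ops above, floor, trunc, and roundeven, differ mainly on negative values and on ties. The sketch below shows their conventional scalar meanings in Python, assuming the graph ops follow those conventions (floor rounds toward negative infinity, trunc toward zero, roundeven to the nearest integer with ties going to the even neighbour). `round` is left out because its tie-breaking rule is not stated here, and the `roundeven` function below is a hypothetical scalar stand-in, not the graph op.

```python
import math


def roundeven(x: float) -> int:
    """Round to the nearest integer, breaking ties toward the even integer."""
    return round(x)  # Python's built-in round() on floats uses ties-to-even


values = [-2.5, -1.7, 2.5, 3.5]
print([math.floor(v) for v in values])  # [-3, -2, 2, 3]
print([math.trunc(v) for v in values])  # [-2, -1, 2, 3]
print([roundeven(v) for v in values])   # [-2, -2, 2, 4]
```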