
Add operator support to custom types

Each Mojo operator maps to a set of dunder methods you can add to your struct implementation. These methods let you use operator syntax instead of calling methods directly.

Once you know which dunder method backs each operator, you can give your custom structs the full range of operator syntax.

Forward, reverse, and in-place methods

Each binary operator uses up to three method forms. For example, consider the addition a + b:

  • Forward: Mojo tries a.__add__(b) first.
  • Reverse: If the forward method doesn't exist or can't handle b's type, Mojo falls back to b.__radd__(a).
  • In-place: For a += b, Mojo calls a.__iadd__(b).

Reversed methods exist for mixed-type expressions where the left operand doesn't know about the right operand's type:

a + 5    # calls a.__add__(5)
5 + a    # Int doesn't know your type, falls back
         # to a.__radd__(5)
a += 5   # calls a.__iadd__(5)
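To see all three forms on one type, here is a minimal sketch using a hypothetical Meters struct (not part of the standard library). It assumes the same Mojo features this page relies on (def methods, @fieldwise_init, TrivialRegisterPassable, and Writable) and implements only the Float64 overloads:

```mojo
# Hypothetical Meters type: shows the forward, reverse, and in-place
# forms of `+` for a Float64 operand.
@fieldwise_init
struct Meters(Writable, TrivialRegisterPassable):
    var value: Float64

    # Forward form: Meters + Float64 calls d.__add__(x)
    def __add__(self, rhs: Float64) -> Self:
        return Self(self.value + rhs)

    # Reverse form: Float64 + Meters falls back to d.__radd__(x)
    def __radd__(self, lhs: Float64) -> Self:
        return Self(lhs + self.value)

    # In-place form: d += x calls d.__iadd__(x) and mutates d
    def __iadd__(mut self, rhs: Float64):
        self.value += rhs

    def write_to(self, mut writer: Some[Writer]):
        writer.write(self.value, " m")


def main():
    var d = Meters(1.5)
    print(d + 2.0)  # 3.5 m   (forward)
    print(2.0 + d)  # 3.5 m   (reverse)
    d += 0.25       # in-place update
    print(d)        # 1.75 m
```

Dropping __radd__() from this sketch would leave d + 2.0 working but make 2.0 + d fail, which is exactly the gap reversed methods fill.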

Unary operators

A unary operator takes a single operand. Its method returns a new value representing the result, or the original value when the operator leaves it unchanged (as unary + usually does). For example, -x calls the unary negation method, __neg__():

@fieldwise_init
struct MyInt:
    var value: Int

    def __neg__(self) -> Self:
        return Self(-self.value)

If x is a MyInt, then -x returns a new instance with its value field negated.
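Negation isn't the only unary operator. The following sketch uses a hypothetical Flags struct to show ~ (__invert__()), which builds a new value, and unary + (__pos__()), which can simply return the original:

```mojo
# Hypothetical bit-flag type used only to illustrate unary operators.
@fieldwise_init
struct Flags(TrivialRegisterPassable):
    var bits: UInt8

    # `~f` calls __invert__(): flip every bit and return a new value.
    def __invert__(self) -> Self:
        return Self(~self.bits)

    # `+f` calls __pos__(): nothing changes, so return self as-is.
    def __pos__(self) -> Self:
        return self


def main():
    var f = Flags(15)     # 0b00001111
    print((~f).bits)      # 240 (0b11110000)
    print((+f).bits)      # 15
```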

Comparison operators and traits

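Operator methods don't require your type to conform to any trait. Conforming to the comparison traits still has benefits: they supply default implementations for some operators, and they let your type work with generic code that requires those traits.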

The Comparable trait provides defaults for <=, >, and >=. You just implement __lt__() and __eq__().

Similarly, the Equatable trait provides defaults for __eq__() and __ne__() when all fields are Equatable.

For types without a natural ordering, such as complex numbers, conform to Equatable only, not Comparable.
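A minimal sketch of the Comparable pattern, using a hypothetical Version struct. It writes only __eq__() and __lt__() and relies on the trait defaults described above for <=, >, and >=. The argument names (other for __eq__(), rhs for __lt__()) are assumed to match the standard library's trait declarations:

```mojo
# Hypothetical version number ordered by (major, minor).
@fieldwise_init
struct Version(Comparable, TrivialRegisterPassable):
    var major: Int
    var minor: Int

    # Required: equality, used by == (and by the default !=).
    def __eq__(self, other: Self) -> Bool:
        return self.major == other.major and self.minor == other.minor

    # Required: ordering. Compare major first, then minor.
    def __lt__(self, rhs: Self) -> Bool:
        if self.major != rhs.major:
            return self.major < rhs.major
        return self.minor < rhs.minor


def main():
    var a = Version(1, 4)
    var b = Version(2, 0)
    print(a < b)   # True  (your __lt__)
    print(a >= b)  # False (trait default)
    print(b > a)   # True  (trait default)
    print(a <= a)  # True  (trait default)
```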

Subscript operators

Implement __getitem__() for reads and __setitem__() for writes. Both methods can take more than one index argument, including variadic arguments, for multi-dimensional indexing.

For a one-dimensional collection, a single Int index is enough:

struct MySeq[T: Copyable]:
    def __getitem__(self, idx: Int) -> T:
        ...
    def __setitem__(mut self, idx: Int, value: T):
        ...
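The MySeq stub above leaves its storage unspecified. Here is a runnable sketch using a hypothetical IntSeq backed by a List[Int]; storing Int sidesteps questions about how generic elements are copied:

```mojo
# Hypothetical one-dimensional collection of Int values.
struct IntSeq:
    var items: List[Int]

    # Create `size` elements, all initialized to 0.
    def __init__(out self, size: Int):
        self.items = List[Int]()
        for _ in range(size):
            self.items.append(0)

    # Reads like seq[i] call __getitem__().
    def __getitem__(self, idx: Int) -> Int:
        return self.items[idx]

    # Writes like seq[i] = v call __setitem__().
    def __setitem__(mut self, idx: Int, value: Int):
        self.items[idx] = value


def main():
    var seq = IntSeq(3)
    seq[1] = 42
    print(seq[0], seq[1], seq[2])  # 0 42 0
```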

For multi-dimensional collections, use multiple index arguments or a variadic index list:

struct Grid[T: Copyable]:
    # Option 1: exactly two dimensions
    def __getitem__(self, x: Int, y: Int) -> T:
        ...
    # Option 2: any number of dimensions
    def __getitem__(self, *indices: Int) -> T:
        ...
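A runnable sketch of the fixed two-dimension form, assuming a hypothetical Grid2D that stores Float64 values row-major in a flat List:

```mojo
# Hypothetical 2-D grid: g[x, y] reads, g[x, y] = v writes.
struct Grid2D:
    var width: Int
    var height: Int
    var data: List[Float64]

    # Allocate width * height cells, all set to 0.0.
    def __init__(out self, width: Int, height: Int):
        self.width = width
        self.height = height
        self.data = List[Float64]()
        for _ in range(width * height):
            self.data.append(0.0)

    # Map (x, y) to an offset in the flat row-major list.
    def _offset(self, x: Int, y: Int) -> Int:
        return y * self.width + x

    def __getitem__(self, x: Int, y: Int) -> Float64:
        return self.data[self._offset(x, y)]

    # The assigned value arrives as the last argument.
    def __setitem__(mut self, x: Int, y: Int, value: Float64):
        self.data[self._offset(x, y)] = value


def main():
    var g = Grid2D(3, 2)
    g[2, 1] = 7.5
    print(g[2, 1], g[0, 0])  # 7.5 0.0
```

Row-major layout is one choice among several; the subscript signature doesn't depend on how you lay out storage.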

Custom subscripts can accept slices, such as obj[1:5], as well as single indices. To support slicing, implement a __getitem__() overload that takes a Slice argument instead of an Int.

Each Slice has three optional fields: start, end, and step. To normalize them, call the slice's indices() method with your collection's size. It returns a (start, end, step) triplet adjusted to that size, filling in omitted values and resolving negative indices into concrete positions:

struct MySeq[T: Copyable]:
    var size: Int

    def __getitem__(self, span: Slice) -> Self:
        var start: Int
        var end: Int
        var step: Int
        start, end, step = span.indices(self.size)
        ...
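Filling in the slicing sketch, the hypothetical IntList below supports both lst[i] and lst[start:end:step]. It assumes indices() returns Python-style bounds (for a full reversed slice, an end of -1), so that range(start, end, step) visits exactly the selected positions:

```mojo
# Hypothetical Int list wrapper supporting indexing and slicing.
@fieldwise_init
struct IntList(Movable, Sized):
    var items: List[Int]

    # Single index: lst[i]
    def __getitem__(self, idx: Int) -> Int:
        return self.items[idx]

    # Slice: lst[start:end:step] returns a new IntList.
    def __getitem__(self, span: Slice) -> Self:
        var start: Int
        var end: Int
        var step: Int
        # Resolve omitted and negative bounds against our length.
        start, end, step = span.indices(len(self.items))
        var picked = List[Int]()
        for i in range(start, end, step):
            picked.append(self.items[i])
        return Self(picked^)

    def __len__(self) -> Int:
        return len(self.items)


def main():
    var items = List[Int]()
    for v in range(10, 60, 10):   # 10, 20, 30, 40, 50
        items.append(v)
    var lst = IntList(items^)

    var mid = lst[1:4]
    print(len(mid), mid[0], mid[2])  # 3 20 40
    var rev = lst[::-1]              # elements in reverse order
    print(rev[0], rev[4])
```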

Walkthrough: Build a Complex type

The next sections incrementally build a Complex struct that demonstrates each category of operator implementation.

In this walkthrough, you'll work with unary operators, binary operators with same-type and mixed-type operands, reversed methods, in-place assignment, equality comparison, and subscript access.

Create the base type

A complex number has a real part and an imaginary part, stored in the re and im fields:

from std.math import sqrt

@fieldwise_init
struct Complex(
    Equatable,
    Writable,
    TrivialRegisterPassable,
):
    var re: Float64
    var im: Float64

  • Conforming to TrivialRegisterPassable gives Complex value semantics without requiring you to write lifecycle methods.
  • Equatable lets you compare two instances with == and !=, and Writable lets print() and String() format the value.

Convenience initializer

Adding a convenience initializer lets you create instances using only the real part of your number:

def __init__(out self, re: Float64):
    self.re = re
    self.im = 0.0

Make your type printable

Implementing Writable lets you pass a Complex directly to print() and String(). This implementation wraps the output in parentheses and writes the imaginary part with an explicit + or - sign:

# Struct method
def write_to(self, mut writer: Some[Writer]):
    writer.write("(", self.re)
    if self.im < 0:
        writer.write(" - ", -self.im)
    else:
        writer.write(" + ", self.im)
    writer.write("i)")

...

c = Complex(3.14, -2.72)
print(c)  # (3.14 - 2.72i)

Add unary operator support

Unary plus (+c) returns the value unchanged, and unary minus (-c) negates both components:

# Struct methods
def __pos__(self) -> Self:
    return self

def __neg__(self) -> Self:
    return Self(-self.re, -self.im)

...

c = Complex(-1.2, 6.5)
print(+c)  # (-1.2 + 6.5i)
print(-c)  # (1.2 - 6.5i)

Support binary arithmetic

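Implement __add__(), __sub__(), __mul__(), and __truediv__() to add, subtract, multiply, and divide two Complex values. Each method returns a new Complex instance. The squared_norm() and norm() helpers are ordinary methods; __truediv__() uses squared_norm() as its denominator: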

    def __add__(self, rhs: Self) -> Self:
        return Self(self.re + rhs.re, self.im + rhs.im)

    def __sub__(self, rhs: Self) -> Self:
        return Self(self.re - rhs.re, self.im - rhs.im)

    def __mul__(self, rhs: Self) -> Self:
        return Self(
            self.re * rhs.re - self.im * rhs.im,
            self.re * rhs.im + self.im * rhs.re,
        )

    def __truediv__(self, rhs: Self) -> Self:
        denom = rhs.squared_norm()
        return Self(
            (self.re * rhs.re + self.im * rhs.im) / denom,
            (self.im * rhs.re - self.re * rhs.im) / denom,
        )

    def squared_norm(self) -> Float64:
        return self.re * self.re + self.im * self.im

    def norm(self) -> Float64:
        return sqrt(self.squared_norm())

...

c1 = Complex(-1.2, 6.5)
c2 = Complex(3.14, -2.72)
print(c1 + c2)  # ≈ (1.94 + 3.78i), up to Float64 rounding
print(c1 * c2)  # ≈ (13.912 + 23.674i), up to Float64 rounding
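The __truediv__() body follows from multiplying the numerator and denominator by the conjugate of the divisor. With a = self.re, b = self.im, c = rhs.re, and d = rhs.im:

```latex
\frac{a + bi}{c + di}
  = \frac{(a + bi)(c - di)}{(c + di)(c - di)}
  = \frac{(ac + bd) + (bc - ad)\,i}{c^2 + d^2}
```

The denominator c² + d² is exactly rhs.squared_norm(), which is why the code computes it once as denom.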

Add mixed-type arithmetic with reversed methods

To support mixed expressions such as c + 2.5 and 2.5 + c, you need both Float64 overloads of the forward methods and matching reversed methods. Without __radd__(), 2.5 + c would fail because Float64 doesn't know about Complex:

    # Forward: Complex + Float64
    def __add__(self, rhs: Float64) -> Self:
        return Self(self.re + rhs, self.im)

    # Reversed: Float64 + Complex
    def __radd__(self, lhs: Float64) -> Self:
        return Self(self.re + lhs, self.im)

    def __sub__(self, rhs: Float64) -> Self:
        return Self(self.re - rhs, self.im)

    def __rsub__(self, lhs: Float64) -> Self:
        return Self(lhs - self.re, -self.im)

    def __mul__(self, rhs: Float64) -> Self:
        return Self(self.re * rhs, self.im * rhs)

    def __rmul__(self, lhs: Float64) -> Self:
        return Self(lhs * self.re, lhs * self.im)

    def __truediv__(self, rhs: Float64) -> Self:
        return Self(self.re / rhs, self.im / rhs)

    def __rtruediv__(self, lhs: Float64) -> Self:
        denom = self.squared_norm()
        return Self(
            (lhs * self.re) / denom,
            (-lhs * self.im) / denom,
        )

Now both orderings work:

c = Complex(-1.2, 6.5)
print(c + 2.5)    # (1.3 + 6.5i)
print(2.5 + c)    # (1.3 + 6.5i)
print(2.5 * c)    # (-3.0 + 16.25i)
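The reversed methods treat the Float64 operand x as the complex number x + 0i. For subtraction and division, with c = self.re and d = self.im:

```latex
x - (c + di) = (x - c) - d\,i
\qquad
\frac{x}{c + di} = \frac{x(c - di)}{c^2 + d^2}
  = \frac{xc}{c^2 + d^2} - \frac{xd}{c^2 + d^2}\,i
```

These match __rsub__() and __rtruediv__(), and they show why reversed methods can't simply reuse the forward versions: subtraction and division aren't commutative.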

Allow in-place assignment

In-place methods modify self directly instead of returning a new value. You can overload for both Complex and Float64 operands:

    def __iadd__(mut self, rhs: Self):
        self.re += rhs.re
        self.im += rhs.im

    def __iadd__(mut self, rhs: Float64):
        self.re += rhs

    def __isub__(mut self, rhs: Self):
        self.re -= rhs.re
        self.im -= rhs.im

    def __isub__(mut self, rhs: Float64):
        self.re -= rhs

    def __imul__(mut self, rhs: Self):
        var new_re = self.re * rhs.re - self.im * rhs.im
        var new_im = self.re * rhs.im + self.im * rhs.re
        self.re = new_re
        self.im = new_im

    def __imul__(mut self, rhs: Float64):
        self.re *= rhs
        self.im *= rhs

    def __itruediv__(mut self, rhs: Self):
        var denom = rhs.squared_norm()
        var new_re = (self.re * rhs.re + self.im * rhs.im) / denom
        var new_im = (self.im * rhs.re - self.re * rhs.im) / denom
        self.re = new_re
        self.im = new_im

    def __itruediv__(mut self, rhs: Float64):
        self.re /= rhs
        self.im /= rhs

...

c = Complex(-1.0, -1.0)
c += Complex(0.5, -0.5)
print(c)       # (-0.5 - 1.5i)
c += 2.75
print(c)       # (2.25 - 1.5i)
c *= 0.75
print(c)       # (1.6875 - 1.125i)
c /= 2.0
print(c)       # (0.84375 - 0.5625i)
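__imul__() and __itruediv__() compute both new components into temporaries before assigning, because each component depends on the old values of both fields. A worked example with (1 + 2i)(3 + 4i) shows what goes wrong otherwise:

```latex
\begin{aligned}
\text{with temporaries:}\quad & (1 + 2i)(3 + 4i) = (1 \cdot 3 - 2 \cdot 4) + (1 \cdot 4 + 2 \cdot 3)\,i = -5 + 10i \\
\text{overwriting re first:}\quad & \mathrm{re} \leftarrow -5, \quad \mathrm{im} \leftarrow (-5) \cdot 4 + 2 \cdot 3 = -14
\end{aligned}
```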

Support type equality checks

Complex numbers have no natural ordering, so you should implement Equatable (not Comparable). This gives you == and != without implying that one complex number is "less than" another:

    def __eq__(self, other: Self) -> Bool:
        return (self.re == other.re and
                self.im == other.im)

    # Optional: conforming to Equatable already provides this default
    def __ne__(self, other: Self) -> Bool:
        return not (self == other)

...

c1 = Complex(-1.2, 6.5)
c2 = Complex(-1.2, 6.5)
c3 = Complex(3.14, -2.72)
print(c1 == c2)  # True
print(c1 != c3)  # True

Unlock subscript access

The __getitem__() and __setitem__() dunders let you index into your type. In this example, index 0 refers to the real part and index 1 to the imaginary part; any other index raises an error:

    def __getitem__(self, idx: Int) raises -> Float64:
        if idx == 0: return self.re
        if idx == 1: return self.im
        raise "index out of bounds"

    def __setitem__(
        mut self, idx: Int, value: Float64
    ) raises:
        if idx == 0: self.re = value
        elif idx == 1: self.im = value
        else: raise "index out of bounds"

...

c = Complex(3.14)
print(c[0], c[1])  # 3.14 0.0
c[1] = 42.0
print(c)            # (3.14 + 42.0i)
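Because __getitem__() and __setitem__() are declared raises, code that subscripts a Complex must run in a raising function (such as def main() raises:) or handle the error itself. A minimal sketch of the handling side, using a hypothetical two-field Pair with the same raising pattern:

```mojo
# Hypothetical pair type whose subscript raises on a bad index,
# following the same pattern as Complex.__getitem__().
@fieldwise_init
struct Pair(TrivialRegisterPassable):
    var first: Float64
    var second: Float64

    def __getitem__(self, idx: Int) raises -> Float64:
        if idx == 0:
            return self.first
        if idx == 1:
            return self.second
        raise Error("index out of bounds")


def main():
    var p = Pair(3.14, 0.0)
    try:
        print(p[0])  # 3.14
        print(p[2])  # raises before anything is printed
    except e:
        print("caught:", e)  # caught: index out of bounds
```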

Every operator, one walkthrough

This walkthrough covered each category of operator implementation: unary and binary arithmetic, mixed-type operands with reversed methods, in-place assignment, equality comparison, and subscripting.

By implementing the matching dunder methods and conforming to traits such as Equatable or Comparable where they fit, you can use operator syntax with nearly any custom type.
