Add operator support to custom types
Each Mojo operator maps to a set of dunder methods you can add to your struct implementation. These methods let you use operator syntax instead of calling methods directly.
Once you know which dunder method backs each operator, you can give your custom structs the full range of operator syntax.
Forward, reverse, and in-place methods
Each binary operator uses up to three method forms. For example,
consider the addition a + b:
- Forward: Mojo tries a.__add__(b) first.
- Reverse: If the forward method doesn't exist or can't handle b's type, Mojo falls back to b.__radd__(a).
- In-place: For a += b, Mojo calls a.__iadd__(b).
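The sketch below shows all three forms on one type. It uses a hypothetical Meters struct; the name and the Float64 operand are illustrative only. It follows the same struct conventions as the Complex walkthrough later on this page, and exact trait names may differ between Mojo versions.

```mojo
# Hypothetical wrapper type, used only to illustrate the three method forms.
@fieldwise_init
struct Meters(TrivialRegisterPassable):
    var value: Float64

    # Forward: `m + 1.5` calls m.__add__(1.5)
    def __add__(self, rhs: Float64) -> Self:
        return Self(self.value + rhs)

    # Reverse: `1.5 + m` falls back to m.__radd__(1.5)
    def __radd__(self, lhs: Float64) -> Self:
        return Self(lhs + self.value)

    # In-place: `m += 0.5` calls m.__iadd__(0.5) and mutates m
    def __iadd__(mut self, rhs: Float64):
        self.value += rhs


def main():
    var m = Meters(2.0)
    print((m + 1.5).value)  # 3.5
    print((1.5 + m).value)  # 3.5
    m += 0.5
    print(m.value)  # 2.5
```

The literal 1.5 has no __add__() that accepts Meters, so 1.5 + m reaches __radd__().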
Reversed methods exist for mixed-type expressions where the left operand doesn't know about the right operand's type:
```mojo
a + 5   # calls a.__add__(5)
5 + a   # Int doesn't know your type, so Mojo falls back
        # to a.__radd__(5)
a += 5  # calls a.__iadd__(5)
```

Unary operators
A unary operator takes a single operand. It can return the operand unchanged (as unary + does) or a new value that represents the result. For example, -x uses the unary negation operator:

```mojo
@fieldwise_init
struct MyInt:
    var value: Int

    def __neg__(self) -> Self:
        return Self(-self.value)
```

If x is a MyInt, then -x returns a new instance with its value field negated.
Comparison operators and traits
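Operators don't require your types to conform to any trait, but conforming to the comparison traits has two benefits. The traits supply default implementations for some operators, and generic code that requires those traits can accept your type.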
The Comparable trait provides defaults for <=, >, and >=. You just
implement __lt__() and __eq__().
Similarly, the Equatable trait provides defaults for __eq__() and
__ne__() when all fields are Equatable.
For types without a natural ordering (like complex numbers), implement only Equatable, not Comparable.
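The sketch below shows the Comparable pattern described above. It uses a hypothetical Version type that is ordered by major number, then minor number. Only __lt__() and __eq__() are written by hand. Following the trait behavior described above, <=, >, and >= come from Comparable's defaults.

```mojo
# Hypothetical type with a natural ordering: compare major, then minor.
@fieldwise_init
struct Version(Comparable, TrivialRegisterPassable):
    var major: Int
    var minor: Int

    # Hand-written strict ordering
    def __lt__(self, other: Self) -> Bool:
        if self.major != other.major:
            return self.major < other.major
        return self.minor < other.minor

    # Hand-written equality
    def __eq__(self, other: Self) -> Bool:
        return self.major == other.major and self.minor == other.minor


def main():
    var a = Version(1, 2)
    var b = Version(1, 5)
    print(a < b)   # True, from __lt__()
    print(a <= b)  # True, default from Comparable
    print(a > b)   # False, default from Comparable
    print(b >= a)  # True, default from Comparable
```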
Subscript operators
Implement __getitem__() for reads and __setitem__() for writes. For multi-dimensional indexing, either method can take several index arguments or a variadic list of indices.
For a simple one-dimensional collection, you unlock subscripting with a simple index:
```mojo
struct MySeq[T: Copyable]:
    def __getitem__(self, idx: Int) -> T:
        ...

    def __setitem__(mut self, idx: Int, value: T):
        ...
```

For multi-dimensional collections, use multiple index arguments or a variadic argument list:
```mojo
struct Grid[T: Copyable]:
    # Fixed two dimensions: grid[x, y]
    def __getitem__(self, x: Int, y: Int) -> T:
        ...

    # Arbitrary dimensions: grid[i, j, k, ...]
    def __getitem__(self, *indices: Int) -> T:
        ...
```

Custom subscripts can support slicing as well as indices, such as obj[1:5]. Implement __getitem__() with a Slice parameter instead of Int.
Each Slice has three optional fields: start, end, and step. To normalize them, call indices() with your type's length. It returns a (start, end, step) triplet adjusted to your extent, with omitted values filled in and negative indices resolved into concrete positions:
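The sketch below is a small runnable version of this pattern. It uses a hypothetical Arange type that describes an arithmetic sequence (first, first + stride, ...) without storing elements, so slicing can return a new Arange instead of copying data. It assumes indices() follows the same conventions as Python's slice.indices(). The element count uses the usual range-length formula.

```mojo
# Hypothetical read-only arithmetic sequence with `size` elements:
# first, first + stride, first + 2 * stride, ...
@fieldwise_init
struct Arange(TrivialRegisterPassable):
    var first: Int
    var stride: Int
    var size: Int

    def __getitem__(self, idx: Int) -> Int:
        # No bounds checking in this sketch.
        return self.first + idx * self.stride

    def __getitem__(self, span: Slice) -> Self:
        var start: Int
        var end: Int
        var step: Int
        # Fill in omitted values and resolve negatives against our length.
        start, end, step = span.indices(self.size)
        # Count how many positions start, start + step, ... fall before end.
        var count: Int
        if step > 0:
            count = max(0, (end - start + step - 1) // step)
        else:
            count = max(0, (start - end - step - 1) // (-step))
        # The slice is itself an arithmetic sequence.
        return Self(self[start], self.stride * step, count)


def main():
    var r = Arange(0, 3, 10)  # 0, 3, 6, ..., 27
    var sub = r[2:8:2]        # positions 2, 4, 6
    print(sub.size, sub[0], sub[2])  # 3 6 18
```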
```mojo
struct MySeq[T: Copyable]:
    var size: Int

    def __getitem__(self, span: Slice) -> Self:
        var start: Int
        var end: Int
        var step: Int
        start, end, step = span.indices(self.size)
        ...
```

Walkthrough: Build a Complex type
The next sections incrementally build a Complex struct that exercises every category of operator method described above. The walkthrough covers unary operators, binary operators with same-type and mixed-type operands, reversed methods, in-place assignment, equality comparison, and subscript access.
Create the base type
A complex number has a real part and an imaginary part, stored in the re and im fields:
```mojo
from std.math import sqrt

@fieldwise_init
struct Complex(
    Equatable,
    Writable,
    TrivialRegisterPassable,
):
    var re: Float64
    var im: Float64
```

- Conforming to TrivialRegisterPassable gives you value semantics without needing to write special lifecycle methods.
- Equatable lets you compare two instances, and Writable produces output for print() statements.
Convenience initializer
Adding a convenience initializer lets you create instances using only the real part of your number:
```mojo
def __init__(out self, re: Float64):
    self.re = re
    self.im = 0.0
```

Make your type printable
Implementing Writable lets you pass a Complex directly to print() and String(). This implementation wraps the value in parentheses and writes the sign of the imaginary part explicitly, so a negative imaginary part prints as a subtraction:
```mojo
# Struct method
def write_to(self, mut writer: Some[Writer]):
    writer.write("(", self.re)
    if self.im < 0:
        writer.write(" - ", -self.im)
    else:
        writer.write(" + ", self.im)
    writer.write("i)")

# Usage
c = Complex(3.14, -2.72)
print(c)  # (3.14 - 2.72i)
```

Add unary operator support
+c returns the value unchanged. -c negates both
components:
```mojo
# Struct methods
def __pos__(self) -> Self:
    return self

def __neg__(self) -> Self:
    return Self(-self.re, -self.im)

# Usage
c = Complex(-1.2, 6.5)
print(+c)  # (-1.2 + 6.5i)
print(-c)  # (1.2 - 6.5i)
```

Support binary arithmetic
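Support addition, subtraction, multiplication, and division between two Complex values by implementing __add__(), __sub__(), __mul__(), and __truediv__(). Each method returns a new Complex instance. Division uses the squared_norm() helper, defined alongside it, as its denominator: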
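The division methods on this page follow the standard identity for dividing complex numbers: multiply the numerator and the denominator by the conjugate of the divisor. Let a = self.re, b = self.im, c = rhs.re, and d = rhs.im. The denominator c² + d² is the value that squared_norm() returns. Setting b = 0 with a real numerator x gives the Float64-divided-by-Complex case that __rtruediv__() implements.

```latex
\frac{a + bi}{c + di}
  = \frac{(a + bi)(c - di)}{(c + di)(c - di)}
  = \frac{ac + bd}{c^2 + d^2} + \frac{bc - ad}{c^2 + d^2}\, i

\frac{x}{c + di} = \frac{xc}{c^2 + d^2} - \frac{xd}{c^2 + d^2}\, i
```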
```mojo
def __add__(self, rhs: Self) -> Self:
    return Self(self.re + rhs.re, self.im + rhs.im)

def __sub__(self, rhs: Self) -> Self:
    return Self(self.re - rhs.re, self.im - rhs.im)

def __mul__(self, rhs: Self) -> Self:
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    return Self(
        self.re * rhs.re - self.im * rhs.im,
        self.re * rhs.im + self.im * rhs.re,
    )

def __truediv__(self, rhs: Self) -> Self:
    denom = rhs.squared_norm()
    return Self(
        (self.re * rhs.re + self.im * rhs.im) / denom,
        (self.im * rhs.re - self.re * rhs.im) / denom,
    )

# Helpers: |z|^2 and |z|
def squared_norm(self) -> Float64:
    return self.re * self.re + self.im * self.im

def norm(self) -> Float64:
    return sqrt(self.squared_norm())
```

```mojo
c1 = Complex(-1.2, 6.5)
c2 = Complex(3.14, -2.72)
print(c1 + c2)  # (1.94 + 3.78i), up to floating-point rounding
print(c1 * c2)  # (13.912 + 23.674i), up to floating-point rounding
```

Add mixed-type arithmetic with reversed methods
To support expressions like 2.5 + c where Float64 is on
the left, you need both overloaded forward methods and
reversed methods. Without __radd__(), 2.5 + c would fail
because Float64 doesn't know about Complex:
```mojo
# Forward: Complex + Float64
def __add__(self, rhs: Float64) -> Self:
    return Self(self.re + rhs, self.im)

# Reversed: Float64 + Complex
def __radd__(self, lhs: Float64) -> Self:
    return Self(self.re + lhs, self.im)

def __sub__(self, rhs: Float64) -> Self:
    return Self(self.re - rhs, self.im)

def __rsub__(self, lhs: Float64) -> Self:
    return Self(lhs - self.re, -self.im)

def __mul__(self, rhs: Float64) -> Self:
    return Self(self.re * rhs, self.im * rhs)

def __rmul__(self, lhs: Float64) -> Self:
    return Self(lhs * self.re, lhs * self.im)

def __truediv__(self, rhs: Float64) -> Self:
    return Self(self.re / rhs, self.im / rhs)

def __rtruediv__(self, lhs: Float64) -> Self:
    denom = self.squared_norm()
    return Self(
        (lhs * self.re) / denom,
        (-lhs * self.im) / denom,
    )
```

Now both orderings work:

```mojo
c = Complex(-1.2, 6.5)
print(c + 2.5)  # (1.3 + 6.5i)
print(2.5 + c)  # (1.3 + 6.5i)
print(2.5 * c)  # (-3.0 + 16.25i)
```

Allow in-place assignment
In-place methods modify self directly instead of returning
a new value. You can overload for both Complex and
Float64 operands:
```mojo
def __iadd__(mut self, rhs: Self):
    self.re += rhs.re
    self.im += rhs.im

def __iadd__(mut self, rhs: Float64):
    self.re += rhs

def __isub__(mut self, rhs: Self):
    self.re -= rhs.re
    self.im -= rhs.im

def __isub__(mut self, rhs: Float64):
    self.re -= rhs

def __imul__(mut self, rhs: Self):
    # Compute both parts before assigning: new_im
    # still needs the old value of self.re.
    var new_re = self.re * rhs.re - self.im * rhs.im
    var new_im = self.re * rhs.im + self.im * rhs.re
    self.re = new_re
    self.im = new_im

def __imul__(mut self, rhs: Float64):
    self.re *= rhs
    self.im *= rhs

def __itruediv__(mut self, rhs: Self):
    var denom = rhs.squared_norm()
    var new_re = (self.re * rhs.re + self.im * rhs.im) / denom
    var new_im = (self.im * rhs.re - self.re * rhs.im) / denom
    self.re = new_re
    self.im = new_im

def __itruediv__(mut self, rhs: Float64):
    self.re /= rhs
    self.im /= rhs

# Usage
c = Complex(-1.0, -1.0)
c += Complex(0.5, -0.5)
print(c)  # (-0.5 - 1.5i)
c += 2.75
print(c)  # (2.25 - 1.5i)
c *= 0.75
print(c)  # (1.6875 - 1.125i)
c /= 2.0
print(c)  # (0.84375 - 0.5625i)
```

Support type equality checks
Complex numbers have no natural ordering, so you should
implement Equatable (not Comparable). This gives you
== and != without implying that one complex number is
"less than" another:
```mojo
def __eq__(self, other: Self) -> Bool:
    return (self.re == other.re and
            self.im == other.im)

# Optional: conforming to Equatable already provides this default
def __ne__(self, other: Self) -> Bool:
    return not (self == other)
```

```mojo
c1 = Complex(-1.2, 6.5)
c2 = Complex(-1.2, 6.5)
c3 = Complex(3.14, -2.72)
print(c1 == c2)  # True
print(c1 != c3)  # True
```

Unlock subscript access
The __getitem__() and __setitem__() methods let you index into your type. In this example, index 0 reads or writes the real part, index 1 reads or writes the imaginary part, and any other index raises an error:
```mojo
def __getitem__(self, idx: Int) raises -> Float64:
    if idx == 0: return self.re
    if idx == 1: return self.im
    raise "index out of bounds"

def __setitem__(
    mut self, idx: Int, value: Float64
) raises:
    if idx == 0: self.re = value
    elif idx == 1: self.im = value
    else: raise "index out of bounds"

# Usage: these methods can raise, so call them from a
# function that can raise, or inside a try block.
c = Complex(3.14)
print(c[0], c[1])  # 3.14 0.0
c[1] = 42.0
print(c)  # (3.14 + 42.0i)
```

Every operator, one walkthrough
This walkthrough covered each category of operator method: unary and binary arithmetic, reversed and in-place forms, equality comparison, and subscripting.
Implementing the right dunder methods, and conforming to the right traits where they apply, lets you use operator syntax with nearly any custom type.