Skip to main content
Log in

Mandelbrot in Mojo with Python plots

Not only is Mojo great for writing high-performance code, but it also allows us to leverage the huge Python ecosystem of libraries and tools. With seamless Python interoperability, Mojo can use Python for what it's good at, especially GUIs, without sacrificing performance in critical code. Let's take the classic Mandelbrot set algorithm and implement it in Mojo.

This tutorial shows two aspects of Mojo. First, it shows that Mojo can be used to develop fast programs for irregular applications. It also shows how we can leverage Python for visualizing the results.

Code
import benchmark
from math import iota
from sys import num_physical_cores, simdwidthof
from algorithm import parallelize, vectorize
from complex import ComplexFloat64, ComplexSIMD
from python import Python

# Element types used by the vectorized kernel and its result matrix.
alias float_type = DType.float32
alias int_type = DType.int32
# Two native vectors' worth of float32 lanes per step (presumably to expose
# more instruction-level parallelism than a single vector would).
alias simd_width = simdwidthof[float_type]() * 2
# Time unit used when reporting benchmark means.
alias unit = benchmark.Unit.ms
import benchmark
from math import iota
from sys import num_physical_cores, simdwidthof
from algorithm import parallelize, vectorize
from complex import ComplexFloat64, ComplexSIMD
from python import Python

# Element types used by the vectorized kernel and its result matrix.
alias float_type = DType.float32
alias int_type = DType.int32
# Two native vectors' worth of float32 lanes per step (presumably to expose
# more instruction-level parallelism than a single vector would).
alias simd_width = simdwidthof[float_type]() * 2
# Time unit used when reporting benchmark means.
alias unit = benchmark.Unit.ms

First, set some parameters. You can try changing these to see different results:

# Output image size in pixels.
alias height = 960
alias width = 960
# Iteration cap: a point that has not escaped after this many steps is
# reported with the value MAX_ITERS.
alias MAX_ITERS = 200

# Rectangle of the complex plane to render:
# real axis [min_x, max_x], imaginary axis [min_y, max_y].
alias min_x = -2.0
alias max_x = 0.6
alias min_y = -1.5
alias max_y = 1.5
# Output image size in pixels.
alias height = 960
alias width = 960
# Iteration cap: a point that has not escaped after this many steps is
# reported with the value MAX_ITERS.
alias MAX_ITERS = 200

# Rectangle of the complex plane to render:
# real axis [min_x, max_x], imaginary axis [min_y, max_y].
alias min_x = -2.0
alias max_x = 0.6
alias min_y = -1.5
alias max_y = 1.5

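Here we define a simple Matrix struct. Note that it allocates its buffer on the heap and does not free it automatically, so the owner must call `matrix.data.free()` when done: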

@value
struct Matrix[type: DType, rows: Int, cols: Int]:
    """A fixed-size, row-major 2-D buffer of `type` scalars.

    The buffer is allocated in `__init__` and is not released by this struct;
    whoever owns the matrix must call `data.free()` when finished with it.
    Elements are uninitialized until written with `store`.

    Parameters:
        type: Element dtype.
        rows: Number of rows.
        cols: Number of columns.
    """

    # Pointer to `rows * cols` contiguous elements, stored row by row.
    var data: UnsafePointer[Scalar[type]]

    fn __init__(out self):
        """Allocate storage for `rows * cols` elements (left uninitialized)."""
        self.data = UnsafePointer[Scalar[type]].alloc(rows * cols)

    @always_inline
    fn _offset(self, row: Int, col: Int) -> Int:
        """Flat row-major index of (row, col); no bounds checking."""
        return row * cols + col

    fn __getitem__(self, row: Int, col: Int) -> Scalar[type]:
        """Read the element at (row, col)."""
        return self.data.load(self._offset(row, col))

    fn store[width: Int = 1](self, row: Int, col: Int, val: SIMD[type, width]):
        """Write `width` consecutive elements starting at (row, col)."""
        self.data.store(self._offset(row, col), val)
@value
struct Matrix[type: DType, rows: Int, cols: Int]:
    """A fixed-size, row-major 2-D buffer of `type` scalars.

    The buffer is allocated in `__init__` and is not released by this struct;
    whoever owns the matrix must call `data.free()` when finished with it.
    Elements are uninitialized until written with `store`.

    Parameters:
        type: Element dtype.
        rows: Number of rows.
        cols: Number of columns.
    """

    # Pointer to `rows * cols` contiguous elements, stored row by row.
    var data: UnsafePointer[Scalar[type]]

    fn __init__(out self):
        """Allocate storage for `rows * cols` elements (left uninitialized)."""
        self.data = UnsafePointer[Scalar[type]].alloc(rows * cols)

    @always_inline
    fn _offset(self, row: Int, col: Int) -> Int:
        """Flat row-major index of (row, col); no bounds checking."""
        return row * cols + col

    fn __getitem__(self, row: Int, col: Int) -> Scalar[type]:
        """Read the element at (row, col)."""
        return self.data.load(self._offset(row, col))

    fn store[width: Int = 1](self, row: Int, col: Int, val: SIMD[type, width]):
        """Write `width` consecutive elements starting at (row, col)."""
        self.data.store(self._offset(row, col), val)

The core Mandelbrot algorithm iterates a complex function for each pixel until the value "escapes" the circle of radius 2 centered at the origin of the complex plane (that is, until $|z| > 2$), counting the number of iterations needed to escape:

$$z_{i+1} = z_i^2 + c$$

def mandelbrot_kernel(c: ComplexFloat64) -> Int:
    """Compute the number of steps for the point `c` to escape.

    Starting from z = c, repeatedly applies z <- z*z + c. Returns the 0-based
    index of the first step whose result has |z|^2 > 4, or MAX_ITERS if no
    step within MAX_ITERS does.
    """
    var z = c
    var step = 0
    while step < MAX_ITERS:
        z = z * z + c
        if z.squared_norm() > 4:
            return step
        step += 1
    return MAX_ITERS


def compute_mandelbrot() -> Matrix[float_type, height, width]:
    """Compute the escape-step count of every pixel with the scalar kernel.

    Pixel (row, col) corresponds to the complex point
    (min_x + col * dx) + (min_y + row * dy)i, where dx and dy are the sizes
    of one pixel along each axis.

    Returns:
        A newly allocated `height` x `width` matrix of escape counts. The
        caller owns `matrix.data` and must free it.
    """
    # Create a matrix. Each element of the matrix corresponds to a pixel.
    var matrix = Matrix[float_type, height, width]()

    # Size of one pixel in the complex plane.
    var dx = (max_x - min_x) / width
    var dy = (max_y - min_y) / height

    for row in range(height):
        # Derive coordinates from the pixel index instead of accumulating
        # `+= dx` / `+= dy`, which compounds rounding error across the image.
        var y = min_y + row * dy
        for col in range(width):
            var x = min_x + col * dx
            matrix.store(row, col, mandelbrot_kernel(ComplexFloat64(x, y)))
    return matrix
def mandelbrot_kernel(c: ComplexFloat64) -> Int:
    """Compute the number of steps for the point `c` to escape.

    Starting from z = c, repeatedly applies z <- z*z + c. Returns the 0-based
    index of the first step whose result has |z|^2 > 4, or MAX_ITERS if no
    step within MAX_ITERS does.
    """
    var z = c
    var step = 0
    while step < MAX_ITERS:
        z = z * z + c
        if z.squared_norm() > 4:
            return step
        step += 1
    return MAX_ITERS


def compute_mandelbrot() -> Matrix[float_type, height, width]:
    """Compute the escape-step count of every pixel with the scalar kernel.

    Pixel (row, col) corresponds to the complex point
    (min_x + col * dx) + (min_y + row * dy)i, where dx and dy are the sizes
    of one pixel along each axis.

    Returns:
        A newly allocated `height` x `width` matrix of escape counts. The
        caller owns `matrix.data` and must free it.
    """
    # Create a matrix. Each element of the matrix corresponds to a pixel.
    var matrix = Matrix[float_type, height, width]()

    # Size of one pixel in the complex plane.
    var dx = (max_x - min_x) / width
    var dy = (max_y - min_y) / height

    for row in range(height):
        # Derive coordinates from the pixel index instead of accumulating
        # `+= dx` / `+= dy`, which compounds rounding error across the image.
        var y = min_y + row * dy
        for col in range(width):
            var x = min_x + col * dx
            matrix.store(row, col, mandelbrot_kernel(ComplexFloat64(x, y)))
    return matrix

Plotting the number of iterations to escape with some color gives us the canonical Mandelbrot set plot. To render it we can directly leverage Python's matplotlib right from Mojo!

First, install the required Python libraries (NumPy and Matplotlib) if they are missing:

%%python
from importlib.util import find_spec
import shutil
import subprocess

# Troubleshooting hint appended to the ImportError messages raised by
# install_if_missing below.
fix = """
-------------------------------------------------------------------------
fix following the steps here:
https://github.com/modularml/mojo/issues/1085#issuecomment-1771403719
-------------------------------------------------------------------------
"""

def install_if_missing(name: str) -> None:
    """Install the Python package *name* with pip if it cannot be found.

    Args:
        name: Importable top-level name, also used as the pip package name.

    Raises:
        ImportError: If neither ``python3`` nor ``python`` is on PATH, or the
            ``pip install`` subprocess cannot be run or fails. The message
            ends with the troubleshooting text in ``fix``.
    """
    if find_spec(name):
        return
    print("missing", name)
    print(f"{name} not found, installing...")

    # Pick whichever interpreter name resolves on PATH to run pip with.
    if shutil.which('python3'):
        python = "python3"
    elif shutil.which('python'):
        python = "python"
    else:
        raise ImportError(f"{name} not found: python not on path" + fix)

    try:
        subprocess.check_call([python, "-m", "pip", "install", name])
    except (OSError, subprocess.CalledProcessError) as err:
        raise ImportError(f"{name} not found" + fix) from err

    # Path finders cache directory listings; refresh them so the package
    # just installed can be imported in this same interpreter.
    import importlib
    importlib.invalidate_caches()

# Ensure the packages used for plotting below are importable.
install_if_missing("numpy")
install_if_missing("matplotlib")
%%python
from importlib.util import find_spec
import shutil
import subprocess

# Troubleshooting hint appended to the ImportError messages raised by
# install_if_missing below.
fix = """
-------------------------------------------------------------------------
fix following the steps here:
https://github.com/modularml/mojo/issues/1085#issuecomment-1771403719
-------------------------------------------------------------------------
"""

def install_if_missing(name: str) -> None:
    """Install the Python package *name* with pip if it cannot be found.

    Args:
        name: Importable top-level name, also used as the pip package name.

    Raises:
        ImportError: If neither ``python3`` nor ``python`` is on PATH, or the
            ``pip install`` subprocess cannot be run or fails. The message
            ends with the troubleshooting text in ``fix``.
    """
    if find_spec(name):
        return
    print("missing", name)
    print(f"{name} not found, installing...")

    # Pick whichever interpreter name resolves on PATH to run pip with.
    if shutil.which('python3'):
        python = "python3"
    elif shutil.which('python'):
        python = "python"
    else:
        raise ImportError(f"{name} not found: python not on path" + fix)

    try:
        subprocess.check_call([python, "-m", "pip", "install", name])
    except (OSError, subprocess.CalledProcessError) as err:
        raise ImportError(f"{name} not found" + fix) from err

    # Path finders cache directory listings; refresh them so the package
    # just installed can be imported in this same interpreter.
    import importlib
    importlib.invalidate_caches()

# Ensure the packages used for plotting below are importable.
install_if_missing("numpy")
install_if_missing("matplotlib")
def show_plot[type: DType](matrix: Matrix[type, height, width]):
    """Render a `height` x `width` matrix of escape counts with matplotlib.

    The values are copied into a NumPy float64 array, then shaded with a
    `LightSource` using the `hot` colormap and a `PowerNorm(0.3)`.

    Parameters:
        type: Element dtype of `matrix`.

    Args:
        matrix: Per-pixel escape counts. It is only read; the caller keeps
            ownership of `matrix.data`.
    """
    alias scale = 10
    alias dpi = 64

    np = Python.import_module("numpy")
    plt = Python.import_module("matplotlib.pyplot")
    colors = Python.import_module("matplotlib.colors")

    numpy_array = np.zeros((height, width), np.float64)

    # Copy element by element using a flat row-major index (np.zeros returns
    # a C-ordered array). `ndarray.put` works on NumPy 1.x and 2.x, unlike
    # `ndarray.itemset`, which NumPy 2.0 removed.
    for row in range(height):
        for col in range(width):
            numpy_array.put(row * width + col, matrix[row, col])

    fig = plt.figure(1, [scale, scale * height // width], dpi)
    # Axes spanning the whole figure; the returned Axes object is not needed.
    _ = fig.add_axes([0.0, 0.0, 1.0, 1.0], False, 1)
    light = colors.LightSource(315, 10, 0, 1, 1, 0)

    image = light.shade(numpy_array, plt.cm.hot, colors.PowerNorm(0.3), "hsv", 0, 0, 1.5)
    plt.imshow(image)
    plt.axis("off")
    plt.show()

# Plot the scalar result, then release the matrix's heap buffer:
# Matrix has no destructor, so its memory must be freed explicitly.
naive_matrix = compute_mandelbrot()
show_plot(naive_matrix)
naive_matrix.data.free()
def show_plot[type: DType](matrix: Matrix[type, height, width]):
    """Render a `height` x `width` matrix of escape counts with matplotlib.

    The values are copied into a NumPy float64 array, then shaded with a
    `LightSource` using the `hot` colormap and a `PowerNorm(0.3)`.

    Parameters:
        type: Element dtype of `matrix`.

    Args:
        matrix: Per-pixel escape counts. It is only read; the caller keeps
            ownership of `matrix.data`.
    """
    alias scale = 10
    alias dpi = 64

    np = Python.import_module("numpy")
    plt = Python.import_module("matplotlib.pyplot")
    colors = Python.import_module("matplotlib.colors")

    numpy_array = np.zeros((height, width), np.float64)

    # Copy element by element using a flat row-major index (np.zeros returns
    # a C-ordered array). `ndarray.put` works on NumPy 1.x and 2.x, unlike
    # `ndarray.itemset`, which NumPy 2.0 removed.
    for row in range(height):
        for col in range(width):
            numpy_array.put(row * width + col, matrix[row, col])

    fig = plt.figure(1, [scale, scale * height // width], dpi)
    # Axes spanning the whole figure; the returned Axes object is not needed.
    _ = fig.add_axes([0.0, 0.0, 1.0, 1.0], False, 1)
    light = colors.LightSource(315, 10, 0, 1, 1, 0)

    image = light.shade(numpy_array, plt.cm.hot, colors.PowerNorm(0.3), "hsv", 0, 0, 1.5)
    plt.imshow(image)
    plt.axis("off")
    plt.show()

# Plot the scalar result, then release the matrix's heap buffer:
# Matrix has no destructor, so its memory must be freed explicitly.
naive_matrix = compute_mandelbrot()
show_plot(naive_matrix)
naive_matrix.data.free()

Vectorizing Mandelbrot

We showed a naive implementation of the Mandelbrot algorithm, but there are two things we can do to speed it up. We can stop the loop early once the pixels being computed are known to have escaped, and we can leverage Mojo's access to the hardware by vectorizing the loop so that multiple pixels are computed simultaneously. To do that we will use the `vectorize` higher-order function.

We start by defining our main iteration loop in a vectorized fashion:

fn mandelbrot_kernel_SIMD[
    simd_width: Int
](c: ComplexSIMD[float_type, simd_width]) -> SIMD[int_type, simd_width]:
    """A vectorized implementation of the inner mandelbrot computation.

    Iterates z <- z^2 + c (z starting at 0) on `simd_width` points at once.

    Parameters:
        simd_width: Number of points processed together, one per lane.

    Args:
        c: The points of the complex plane, one per lane.

    Returns:
        Per lane, the number of iterations (at most MAX_ITERS) in which the
        current z satisfied |z|^2 <= 4.
    """
    alias FloatVec = SIMD[float_type, simd_width]
    var c_re = c.re
    var c_im = c.im
    var z_re = FloatVec(0)
    var z_im = FloatVec(0)
    var z_im_sq = FloatVec(0)
    var steps = SIMD[int_type, simd_width](0)

    # Lanes whose latest test found |z|^2 <= 4.
    var in_bounds: SIMD[DType.bool, simd_width] = True
    for _ in range(MAX_ITERS):
        # Exit early once no lane is still in bounds.
        if not any(in_bounds):
            break
        z_im_sq = z_im * z_im
        # Im(z^2 + c) = 2 * re * im + c_im  (re is still the old value here).
        z_im = z_re.fma(z_im + z_im, c_im)
        # |z|^2 of the z from before this iteration's update.
        in_bounds = z_re.fma(z_re, z_im_sq) <= 4
        # Re(z^2 + c) = re^2 - im^2 + c_re
        z_re = z_re.fma(z_re, c_re - z_im_sq)
        steps = in_bounds.select(steps + 1, steps)
    return steps
fn mandelbrot_kernel_SIMD[
    simd_width: Int
](c: ComplexSIMD[float_type, simd_width]) -> SIMD[int_type, simd_width]:
    """A vectorized implementation of the inner mandelbrot computation.

    Iterates z <- z^2 + c (z starting at 0) on `simd_width` points at once.

    Parameters:
        simd_width: Number of points processed together, one per lane.

    Args:
        c: The points of the complex plane, one per lane.

    Returns:
        Per lane, the number of iterations (at most MAX_ITERS) in which the
        current z satisfied |z|^2 <= 4.
    """
    alias FloatVec = SIMD[float_type, simd_width]
    var c_re = c.re
    var c_im = c.im
    var z_re = FloatVec(0)
    var z_im = FloatVec(0)
    var z_im_sq = FloatVec(0)
    var steps = SIMD[int_type, simd_width](0)

    # Lanes whose latest test found |z|^2 <= 4.
    var in_bounds: SIMD[DType.bool, simd_width] = True
    for _ in range(MAX_ITERS):
        # Exit early once no lane is still in bounds.
        if not any(in_bounds):
            break
        z_im_sq = z_im * z_im
        # Im(z^2 + c) = 2 * re * im + c_im  (re is still the old value here).
        z_im = z_re.fma(z_im + z_im, c_im)
        # |z|^2 of the z from before this iteration's update.
        in_bounds = z_re.fma(z_re, z_im_sq) <= 4
        # Re(z^2 + c) = re^2 - im^2 + c_re
        z_re = z_re.fma(z_re, c_re - z_im_sq)
        steps = in_bounds.select(steps + 1, steps)
    return steps

The above function is parameterized on `simd_width` and processes `simd_width` pixels at once. It only exits the loop early once every pixel in the vector has escaped. We can use the same iteration loop as above, but this time we vectorize within each row instead. The `vectorize` function makes this a simple function call. The benchmark can run either in parallel or vectorized only.

fn run_mandelbrot(parallel: Bool) raises -> Float64:
    """Benchmark the vectorized Mandelbrot computation and plot the result.

    Args:
        parallel: If True, rows are distributed with `parallelize`; otherwise
            rows are computed one after another. Each row is vectorized in
            both cases.

    Returns:
        The mean time of one full-image computation, in `unit`.

    Raises:
        Errors from benchmarking or plotting. The result matrix is freed
        whether or not an error occurs.
    """
    var matrix = Matrix[int_type, height, width]()

    @parameter
    fn worker(row: Int):
        """Compute one image row, `simd_width` pixels at a time."""
        # Size of one pixel in the complex plane.
        alias scale_x = (max_x - min_x) / width
        alias scale_y = (max_y - min_y) / height

        @parameter
        fn compute_vector[simd_width: Int](col: Int):
            """Each time we operate on a `simd_width` vector of pixels."""
            # Real parts for columns col, col + 1, ..., col + simd_width - 1.
            var cx = min_x + (col + iota[float_type, simd_width]()) * scale_x
            # All lanes share this row's imaginary part.
            var cy = min_y + row * SIMD[float_type, simd_width](scale_y)
            var c = ComplexSIMD[float_type, simd_width](cx, cy)
            matrix.store(row, col, mandelbrot_kernel_SIMD[simd_width](c))

        # Vectorize the call to compute_vector where call gets a chunk of pixels.
        vectorize[compute_vector, simd_width](width)

    @parameter
    fn bench():
        for row in range(height):
            worker(row)

    @parameter
    fn bench_parallel():
        # One work item per row, with `height` workers requested.
        parallelize[worker](height, height)

    var time: Float64 = 0
    try:
        if parallel:
            time = benchmark.run[bench_parallel](max_runtime_secs=0.5).mean(unit)
        else:
            time = benchmark.run[bench](max_runtime_secs=0.5).mean(unit)
        show_plot(matrix)
    finally:
        # Matrix has no destructor: always release its buffer, even when
        # benchmarking or plotting raises.
        matrix.data.free()
    return time

# Vectorized only: rows are computed sequentially, each row with SIMD.
vectorized = run_mandelbrot(parallel=False)
print("Vectorized:", vectorized, unit)
fn run_mandelbrot(parallel: Bool) raises -> Float64:
    """Benchmark the vectorized Mandelbrot computation and plot the result.

    Args:
        parallel: If True, rows are distributed with `parallelize`; otherwise
            rows are computed one after another. Each row is vectorized in
            both cases.

    Returns:
        The mean time of one full-image computation, in `unit`.

    Raises:
        Errors from benchmarking or plotting. The result matrix is freed
        whether or not an error occurs.
    """
    var matrix = Matrix[int_type, height, width]()

    @parameter
    fn worker(row: Int):
        """Compute one image row, `simd_width` pixels at a time."""
        # Size of one pixel in the complex plane.
        alias scale_x = (max_x - min_x) / width
        alias scale_y = (max_y - min_y) / height

        @parameter
        fn compute_vector[simd_width: Int](col: Int):
            """Each time we operate on a `simd_width` vector of pixels."""
            # Real parts for columns col, col + 1, ..., col + simd_width - 1.
            var cx = min_x + (col + iota[float_type, simd_width]()) * scale_x
            # All lanes share this row's imaginary part.
            var cy = min_y + row * SIMD[float_type, simd_width](scale_y)
            var c = ComplexSIMD[float_type, simd_width](cx, cy)
            matrix.store(row, col, mandelbrot_kernel_SIMD[simd_width](c))

        # Vectorize the call to compute_vector where call gets a chunk of pixels.
        vectorize[compute_vector, simd_width](width)

    @parameter
    fn bench():
        for row in range(height):
            worker(row)

    @parameter
    fn bench_parallel():
        # One work item per row, with `height` workers requested.
        parallelize[worker](height, height)

    var time: Float64 = 0
    try:
        if parallel:
            time = benchmark.run[bench_parallel](max_runtime_secs=0.5).mean(unit)
        else:
            time = benchmark.run[bench](max_runtime_secs=0.5).mean(unit)
        show_plot(matrix)
    finally:
        # Matrix has no destructor: always release its buffer, even when
        # benchmarking or plotting raises.
        matrix.data.free()
    return time

# Vectorized only: rows are computed sequentially, each row with SIMD.
vectorized = run_mandelbrot(parallel=False)
print("Vectorized:", vectorized, unit)

Parallelizing Mandelbrot

While the vectorized implementation above is efficient, we can get better performance by also parallelizing across rows, since each row can be computed independently. This is again simple in Mojo using the `parallelize` higher-order function:

# Vectorized and parallel: rows are distributed across workers with `parallelize`.
parallelized = run_mandelbrot(parallel=True)
print("Parallelized:", parallelized, unit)
# Vectorized and parallel: rows are distributed across workers with `parallelize`.
parallelized = run_mandelbrot(parallel=True)
print("Parallelized:", parallelized, unit)

Benchmarking

In this section, we compare the vectorized speed to the parallelized speed.

# Report both timings in the configured `unit` (rather than a hard-coded
# "ms") so the labels stay correct if `unit` is changed.
print("Number of physical cores:", num_physical_cores())
print("Vectorized:", vectorized, unit)
print("Parallelized:", parallelized, unit)
# Ratio > 1 means the parallel version is faster.
print("Parallel speedup:", vectorized / parallelized)
# Report both timings in the configured `unit` (rather than a hard-coded
# "ms") so the labels stay correct if `unit` is changed.
print("Number of physical cores:", num_physical_cores())
print("Vectorized:", vectorized, unit)
print("Parallelized:", parallelized, unit)
# Ratio > 1 means the parallel version is faster.
print("Parallel speedup:", vectorized / parallelized)

Was this page helpful?