Learn how to write high-performance Mojo code and import Python packages.
Mandelbrot in Mojo with Python plots
Not only is Mojo great for writing high-performance code, but it also allows us to leverage the huge Python ecosystem of libraries and tools. With seamless Python interoperability, Mojo can use Python for what it’s good at, especially GUIs, without sacrificing performance in critical code. Let’s take the classic Mandelbrot set algorithm and implement it in Mojo.
This tutorial shows two aspects of Mojo. First, it shows that Mojo can be used to develop fast programs for irregular applications. It also shows how we can leverage Python for visualizing the results.
Import utilities and define Matrix
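The setup cell itself is not reproduced here; it pulls in the complex-number types used below (Complex and ComplexGenericSIMD) and defines the Matrix type that holds one escape count per pixel. Purely as orientation, a minimal Matrix along the following lines would be compatible with the rest of the code. Treat it as a sketch: the module paths and the flat float32 buffer layout are assumptions, not the notebook's actual definition.

from DType import DType
from Pointer import DTypePointer
from SIMD import SIMD, F32

# Hypothetical sketch: a flat float32 buffer indexed as [col, row],
# with a SIMD store for writing simd_width consecutive pixels of a row.
struct Matrix:
    var data: DTypePointer[DType.f32]
    var cols: Int
    var rows: Int

    fn __init__(inout self, cols: Int, rows: Int):
        self.data = DTypePointer[DType.f32].alloc(cols * rows)
        self.cols = cols
        self.rows = rows

    fn __getitem__(self, col: Int, row: Int) -> F32:
        return self.data.load(row * self.cols + col)

    fn __setitem__(self, col: Int, row: Int, val: F32):
        self.data.store(row * self.cols + col, val)

    fn store[simd_width: Int](self, col: Int, row: Int, val: SIMD[DType.f32, simd_width]):
        # Write simd_width consecutive pixels of one row in a single store.
        self.data.simd_store[simd_width](row * self.cols + col, val)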
Then we can write the core Mandelbrot algorithm, which involves computing an iterative complex function for each pixel until it “escapes” the complex circle of radius 2, counting the number of iterations to escape.
\[z_{i+1} = z_i^2 + c\]
alias xmin: F32 = -2
alias xmax: F32 = 0.6
alias xn = 960
alias ymin: F32 = -1.5
alias ymax: F32 = 1.5
alias yn = 768
alias MAX_ITERS = 200

# Compute the number of steps to escape.
def mandelbrot_kernel(c: Complex) -> Int:
    z = c
    for i in range(MAX_ITERS):
        z = z * z + c
        if z.norm() > 4:
            return i
    return MAX_ITERS

def compute_mandelbrot() -> Matrix:
    # Create a matrix. Each element of the matrix corresponds to a pixel.
    result = Matrix(xn, yn)

    dx = (xmax - xmin) / xn
    dy = (ymax - ymin) / yn

    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            result[i, j] = mandelbrot_kernel(Complex(x, y))
            x += dx
        y += dy
    return result
Plotting the number of iterations to escape with some color gives us the canonical Mandelbrot set plot. To render it we can directly leverage Python’s matplotlib right from Mojo!
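The make_plot helper that renders the image lives in the hidden setup cell above. As an illustration only, a plotting routine along the following lines works through Mojo's Python interop; the import path (PythonInterface) and the element-by-element copy into numpy are assumptions and may differ from the notebook's actual helper.

from PythonInterface import Python

def make_plot(m: Matrix):
    # Hand the escape counts to matplotlib through Python interop.
    let np = Python.import_module("numpy")
    let plt = Python.import_module("matplotlib.pyplot")

    # Copy the Mojo matrix element by element into a flat numpy buffer.
    let arr = np.zeros(yn * xn, np.float32)
    for row in range(yn):
        for col in range(xn):
            _ = arr.itemset(row * xn + col, m[col, row])

    plt.imshow(arr.reshape(yn, xn))
    plt.show()

make_plot(compute_mandelbrot())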
We showed a naive implementation of the Mandelbrot algorithm above, but there are two things we can do to speed it up. We can stop the iteration loop early once a pixel is known to have escaped, and we can leverage Mojo's access to the hardware by vectorizing the loop so that multiple pixels are computed simultaneously. To do that we will use the vectorize higher-order generator.
We start by defining our main iteration loop in a vectorized fashion:
fn mandelbrot_kernel_simd[simd_width: Int](c: ComplexGenericSIMD[DType.f32, simd_width]) -> SIMD[DType.f32, simd_width]:
    var z = c
    var nv = SIMD[DType.f32, simd_width](0)
    var escape_mask = SIMD[DType.bool, simd_width](0)

    for i in range(MAX_ITERS):
        if escape_mask:  # All the elements have escaped, so exit.
            break
        z = z * z + c
        # Only update elements that haven't escaped yet.
        escape_mask = escape_mask.select(escape_mask, z.norm() > 4)
        nv = escape_mask.select(nv, nv + 1)

    return nv
The function above is parameterized on simd_width and processes simd_width pixels at a time. It only exits the loop once all pixels in the vector have escaped (or the iteration limit is reached). We use the same iteration loop as before, but this time we vectorize within each row. The vectorize higher-order generator turns this into a simple function call.
from Functional import vectorize
from Math import iota
from TargetInfo import dtype_simd_width

def compute_mandelbrot_simd() -> Matrix:
    # Create a matrix. Each element of the matrix corresponds to a pixel.
    var result = Matrix(xn, yn)

    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn

    var y = ymin
    alias simd_width = dtype_simd_width[DType.f32]()

    for row in range(yn):
        var x = xmin

        @parameter
        fn _process_simd_element[simd_width: Int](col: Int):
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx * iota[simd_width, DType.f32]() + x, y)
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            x += simd_width * dx

        vectorize[simd_width, _process_simd_element](xn)
        y += dy

    return result

make_plot(compute_mandelbrot_simd())
print("finished")
finished
Parallelizing Mandelbrot
While the vectorized implementation above is efficient, we can get even better performance by parallelizing across rows. This is again simple in Mojo, using the parallelize higher-order function; only the function that performs the invocation needs to change.
from Functional import parallelize

def compute_mandelbrot_simd_parallel() -> Matrix:
    # Create a matrix. Each element of the matrix corresponds to a pixel.
    var result = Matrix(xn, yn)

    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn

    alias simd_width = dtype_simd_width[DType.f32]()

    @parameter
    fn _process_row(row: Int):
        var y = ymin + dy * row
        var x = xmin

        @parameter
        fn _process_simd_element[simd_width: Int](col: Int):
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx * iota[simd_width, DType.f32]() + x, y)
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            x += simd_width * dx

        vectorize[simd_width, _process_simd_element](xn)

    parallelize[_process_row](yn)
    return result

make_plot(compute_mandelbrot_simd_parallel())
print("finished")
finished
Benchmarking
In this section we benchmark the sequential (vectorized) implementation against the parallel implementation. As you can see, you get almost a 2x speedup. To get more pronounced speedups, try increasing MAX_ITERS (e.g., to 1000 or more) and/or the image size (e.g., to 4096x4096).
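A simple way to measure this yourself is a wall-clock harness around the two compute functions. The sketch below assumes a Time module with a nanosecond-resolution now() and an F64 alias; the notebook's own benchmark utilities may differ.

from Time import now

def benchmark_mandelbrot():
    # Run each version once first so compilation time is not measured.
    _ = compute_mandelbrot_simd()
    _ = compute_mandelbrot_simd_parallel()

    let t0 = now()
    _ = compute_mandelbrot_simd()
    let t1 = now()
    _ = compute_mandelbrot_simd_parallel()
    let t2 = now()

    # now() is assumed to return nanoseconds; convert to milliseconds.
    let simd_ms = F64(t1 - t0) / 1000000
    let parallel_ms = F64(t2 - t1) / 1000000
    print("vectorized (ms):", simd_ms)
    print("parallelized (ms):", parallel_ms)
    print("speedup:", simd_ms / parallel_ms)

benchmark_mandelbrot()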