Unable to optimize Fractal code with Numba

I am writing code to visualize Mandelbrot sets and other fractals. Below is a snippet of the code being run. It works fine as is, but I am trying to optimize it so that I can render higher-resolution images faster. I've tried caching fractal(), as well as @jit and @njit from Numba. Caching resulted in a crash (from a memory overflow, I assume), and @jit just slows my program down by a factor of 6. I am also aware of the many mathematical ways of making the code run faster, as described on the Wikipedia page, but I would like to see if I can get one of the above methods, or some alternative, to work.

For creating multiple images in a row (to make a zoom animation, like this one) I've implemented multiprocessing (which seems to run 9 processes at once), but I don't know how to do the same for a single high-resolution image.

Here is my code snippet:

import sys  # needed for the sys.stdout progress output below

import numpy as np
import cv2
import cmath
import math

# pick the fractal (fractal_type and d are globals defined elsewhere in the file)
def fractal(z, c):
    # Mandelbrot
    if fractal_type == 0:
        return z**d + c
    # Burning Ship
    if fractal_type == 1:
        return complex(abs(z.real), abs(z.imag))**d + c

# naive escape-time algorithm; arr packs (h, w, d, zoom exponent, x_cen, y_cen)
def naive_escape(arr):
    h = arr[0]
    w = arr[1]
    d = arr[2]
    zoom = pow(1.5, arr[3]) * pow(10,int(np.log10(h)))
    x_cen = arr[4]
    y_cen = arr[5]

    for i in range(w):
        sys.stdout.write("\r{0:03}%".format(np.round(i/w * 100, 4)))

        for j in range(h):
            it = 0
            cx = i - int(w/2)
            cy = j - int(h/2)
            sx = (cx / (zoom)) + x_cen
            sy = (cy / (zoom)) - y_cen

            c = complex(sx,sy)
            z = complex(0.0,0.0)

            while ((z.real)**2 + (z.imag)**2 <= 2**d) and (it < max_it):
                z = fractal(z,c)
                it += 1

            img[j][i] = color_dict[it]


    name = "fractal"

    cv2.imwrite("{}.png".format(name), img)
    print("\n{} created!\n".format(name), fractal_type)

I should clarify that the coloring function naive_escape() takes a single array argument because of how I use multiprocessing. Since map() in multiprocessing passes only one argument to the mapped function, I pack all my input values into one array.

The code pasted above is a snippet from a much bigger file, so please excuse me for any syntax errors.

Any help in making my code faster would be greatly appreciated!


  • This older answer deals specifically with vectorization, but some additional optimization can be done.

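    You can start with NumPy vectorization via np.vectorize, which is convenient but not really fast (it still calls the Python function once per element):

    import numpy as np
    import numba as nb

    @np.vectorize
    def mandelbrot_numpy(c: complex, max_it: int) -> int:
        z = c
        for i in range(max_it):
            if abs(z) > 2:
                return i
            z = z**2 + c
        return 0  # did not escape within max_it iterations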

    Or Numba vectorization, which improves speed by an order of magnitude:

    @nb.vectorize([nb.u2(nb.c16, nb.i8)])
    def mandelbrot_numba(c: complex, max_it: int) -> int:
        z = c
        for i in range(max_it):
            if abs(z) > 2:
                return i
            z = z**2 + c
        return 0

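    Then you can apply some of the usual optimizations: work on the real and imaginary parts directly, and compare the squared modulus with 4 so that no square root is needed: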

    @nb.vectorize([nb.u2(nb.c16, nb.u2)])
    def mandelbrot_numba_opt(c: complex, max_it: int) -> int:
        x = cx = c.real
        y = cy = c.imag
        for i in range(max_it):
            x2 = x*x
            y2 = y*y
            if x2 + y2 > 4:
                return i
            y = (x+x)*y + cy
            x = x2 - y2 + cx
        return 0

    And you can also parallelize it (by rows in this example):

    @nb.njit([nb.u2[:,:](nb.c16[:,:], nb.u2)], parallel=True)
    def mandelbrot_parallel(c: np.ndarray, max_it: int) -> np.ndarray:
        result = np.zeros_like(c, dtype=nb.u2)
        for row in nb.prange(len(c)):
            result[row] = mandelbrot_numba_opt(c[row], max_it)
        return result

    Some timings on a 1000x1000 array:

    N = 1000
    x = np.linspace(-2, 2, N).reshape((1, -1))
    y = x.T
    c = x + 1j * y
    %timeit mandelbrot_numpy(c, 99)
    1.59 s ± 40.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    %timeit mandelbrot_numba(c, 99)
    100 ms ± 406 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    %timeit mandelbrot_numba_opt(c, 99)
    35 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    %timeit mandelbrot_parallel(c, 99)
    10.9 ms ± 64.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
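
The kernels above hard-code the Mandelbrot map with exponent 2, whereas the question also asks about the Burning Ship fractal. As a rough, unbenchmarked sketch, the same real-arithmetic loop can take a flag that switches between the two maps. It is written in plain Python so that it runs without Numba. The names `escape_count` and `burning_ship` are made up for this illustration, and the exponent is fixed at 2 rather than the question's general d. It keeps the convention used above of starting at z = c and returning 0 for points that never escape.

```python
def escape_count(c: complex, max_it: int, burning_ship: bool = False) -> int:
    """Return the iteration at which |z| first exceeds 2, or 0 if it never does.

    Hypothetical helper for illustration only; exponent fixed at 2.
    """
    x = cx = c.real
    y = cy = c.imag
    for i in range(max_it):
        x2 = x * x
        y2 = y * y
        if x2 + y2 > 4:  # squared modulus test, no square root needed
            return i
        two_xy = (x + x) * y
        if burning_ship:
            # (|x| + i|y|)**2 has imaginary part 2*|x|*|y|; the real part
            # x*x - y*y is unaffected by taking absolute values.
            two_xy = abs(two_xy)
        y = two_xy + cy
        x = x2 - y2 + cx
    return 0


print(escape_count(-1 + 1j, 99))                     # 2
print(escape_count(-1 + 1j, 99, burning_ship=True))  # 1
```

Whether this body compiles unchanged under nb.vectorize with an extra boolean argument has not been checked here, so treat it as a starting point rather than a drop-in replacement.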