
How could I zoom in on a generated Mandelbrot set without consuming too many resources?


I am trying to make a Mandelbrot set display, with the following code:

import numpy as np
import matplotlib.pyplot as plt

plt.rcParams['toolbar'] = 'None'

def mandelbrot(c, max_iter):
    z = 0
    for n in range(max_iter):
        if abs(z) > 2:
            return n
        z = z*z + c
    return max_iter

def mandelbrot_set(xmin, xmax, ymin, ymax, width, height, max_iter):
    r1 = np.linspace(xmin, xmax, width)
    r2 = np.linspace(ymin, ymax, height)
    n3 = np.empty((width, height))

    for i in range(width):
        for j in range(height):
            n3[i, j] = mandelbrot(r1[i] + 1j*r2[j], max_iter)
    return n3.T

# Settings
xmin, xmax, ymin, ymax = -2.0, 1.0, -1.5, 1.5
width, height = 800, 800
max_iter = 256

# Generate Mandelbrot set
mandelbrot_image = mandelbrot_set(xmin, xmax, ymin, ymax, width, height, max_iter)

# Window
fig = plt.figure(figsize=(5, 5))
fig.canvas.manager.set_window_title('Mandelbrot Set')
ax = fig.add_axes([0, 0, 1, 1])   # Fill the whole window
ax.set_axis_off()

# Show fractal
ax.imshow(mandelbrot_image, extent=(xmin, xmax, ymin, ymax), cmap='hot')
plt.show()

How could I zoom in on the fractal continuously, without taking up too many resources? I am running on a mid-range laptop, and it currently takes a long time to generate the fractal. Is there a faster way to do this when implementing a zoom feature?


Solution

  • Your Python loop handles one NumPy scalar at a time, which is about the slowest way to use NumPy. It would already be about twice as fast with plain Python numbers, which you can get with .tolist():

        r1 = np.linspace(xmin, xmax, width).tolist()
        r2 = np.linspace(ymin, ymax, height).tolist()
    

    But it's better to use NumPy properly: work on all pixels at once, and keep track of only the values (and their indices) that still have abs ≤ 2:

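    def mandelbrot_set(xmin, xmax, ymin, ymax, width, height, max_iter):
        r1 = np.linspace(xmin, xmax, width)
        r2 = np.linspace(ymin, ymax, height)
        n3 = np.empty(width * height)            # escape counts, one per pixel
    
        z = np.zeros(width * height, dtype=complex)
        c = np.add.outer(r1, 1j*r2).flatten()    # every grid point as a complex number
        i = np.arange(width * height)            # flat indices of the points still iterating
        
        for n in range(max_iter):
            outside = np.abs(z) > 2
            n3[i[outside]] = n                   # record the iteration at which these escaped
            inside = ~outside
            # Keep only the points that have not escaped yet
            z = z[inside]
            c = c[inside]
            i = i[inside]
            z = z*z + c
        n3[i] = max_iter                         # never escaped within max_iter
    
        return n3.reshape((width, height)).T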
    

    This now takes me about 0.17 seconds instead of about 6.7 seconds with your original.
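
For the zoom itself, one possible approach is to recompute only the visible region at a fixed pixel resolution each time the view changes, so the cost per redraw stays roughly constant however deep you go. The following is a minimal sketch under some assumptions of mine, not a definitive implementation: zooming is driven by the mouse wheel, and the names zoom_bounds, run_viewer, step, and state are hypothetical helpers that do not appear in the question. It reuses the vectorized mandelbrot_set from above and passes origin="lower" to imshow so that the cursor's y coordinate matches the data row underneath it.

```python
import numpy as np


def mandelbrot_set(xmin, xmax, ymin, ymax, width, height, max_iter):
    """Vectorized escape-time counts, same approach as the function above."""
    r1 = np.linspace(xmin, xmax, width)
    r2 = np.linspace(ymin, ymax, height)
    n3 = np.empty(width * height)
    z = np.zeros(width * height, dtype=complex)
    c = np.add.outer(r1, 1j * r2).flatten()
    i = np.arange(width * height)
    for n in range(max_iter):
        outside = np.abs(z) > 2
        n3[i[outside]] = n
        inside = ~outside
        z, c, i = z[inside], c[inside], i[inside]
        z = z * z + c
    n3[i] = max_iter
    return n3.reshape((width, height)).T


def zoom_bounds(bounds, cx, cy, factor):
    """Scale the view around (cx, cy); factor < 1 zooms in, factor > 1 zooms out.

    Hypothetical helper: the point (cx, cy) keeps its position on screen.
    """
    xmin, xmax, ymin, ymax = bounds
    return (cx + (xmin - cx) * factor, cx + (xmax - cx) * factor,
            cy + (ymin - cy) * factor, cy + (ymax - cy) * factor)


def run_viewer(width=400, height=400, max_iter=256, step=0.5):
    """Open a window and zoom with the mouse wheel (assumed interaction)."""
    # Imported here so the helpers above need only NumPy.
    import matplotlib.pyplot as plt

    state = {"bounds": (-2.0, 1.0, -1.5, 1.5)}
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_axes([0, 0, 1, 1])
    ax.set_axis_off()
    # origin="lower" puts row 0 (ymin) at the bottom, matching the extent.
    # Fixed vmin/vmax keep the colors consistent between redraws.
    image = ax.imshow(
        mandelbrot_set(*state["bounds"], width, height, max_iter),
        extent=state["bounds"], cmap="hot", origin="lower",
        vmin=0, vmax=max_iter,
    )

    def on_scroll(event):
        if event.inaxes is not ax:
            return
        factor = step if event.button == "up" else 1 / step
        state["bounds"] = zoom_bounds(state["bounds"], event.xdata, event.ydata, factor)
        xmin, xmax, ymin, ymax = state["bounds"]
        # Recompute only the visible region, always at width x height pixels.
        image.set_data(mandelbrot_set(xmin, xmax, ymin, ymax, width, height, max_iter))
        image.set_extent(state["bounds"])
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect("scroll_event", on_scroll)
    plt.show()
```

Calling run_viewer() opens the window. The default 400×400 grid is a guess at what stays responsive on a mid-range laptop; lower width and height if scrolling feels sluggish, and raise max_iter as you zoom deeper if the boundary starts to look blurry.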