
Is there a way to improve the performance of this fractal calculation algorithm?

Yesterday I came across the new 3Blue1Brown video about Newton's fractal and I was really mesmerized by his live representation of the fractal. (For anybody interested, the relevant part of the video is at 13:40.)

I wanted to have a go at it myself and tried to code it in Python (I think he uses Python too).

I spent a few hours trying to improve my naive implementation and got to a point where I just don't know how I could make it faster.

The code looks like this:

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from time import time

def print_fractal(state):
    fig = plt.figure(figsize=(8, 8))
    gs = GridSpec(1, 1)
    axs = [fig.add_subplot(gs[0, 0])]

def get_function_value(z):
    return z**5 + z**2 - z + 1

def get_function_derivative_value(z):
    return 5*z**4 + 2*z - 1

def check_distance(state, roots):
    roots2 = np.zeros((roots.shape[0], state.shape[0], state.shape[1]), dtype=complex)
    for r in range(roots.shape[0]):
        roots2[r] = np.full((state.shape[0], state.shape[1]), roots[r])
    dist_2 = np.abs((roots2 - state))
    original_state = np.argmin(dist_2, axis=0) + 1
    return original_state

def static():
    time_start = time()
    s = 4
    c = [0, 0]
    n = 800
    polynomial = [1, 0, 0, 1, -1, 1]
    roots = np.roots(polynomial)
    state = np.transpose((np.linspace(c[0] - s/2, c[0] + s/2, n)[:, None] + 1j*np.linspace(c[1] - s/2, c[1] + s/2, n)))
    n_steps = 15
    time_setup = time()
    for _ in range(n_steps):
        state -= (get_function_value(state) / get_function_derivative_value(state))
    time_evolution = time()
    original_state = check_distance(state, roots)
    time_check = time()
    print("{0:<40}".format("Time to setup the initial configuration:"), "{:20.3f}".format(time_setup - time_start))
    print("{0:<40}".format("Time to evolve the state:"), "{:20.3f}".format(time_evolution - time_setup))
    print("{0:<40}".format("Time to check the closest roots:"), "{:20.3f}".format(time_check - time_evolution))

An average output looks like this:

Time to setup the initial configuration: 0.004
Time to evolve the state: 0.796
Time to check the closest roots: 0.094

It's clear that it's the evolution part that bottlenecks the process. It's not "slow", but I think it's not enough to render something live like in the video. I already did what I could by using numpy vectors and avoiding loops but I guess it's not enough. What other tricks could be applied here?

Note: I tried using the numpy.polynomial.Polynomial class to evaluate the function, but it was slower than this version.


1. I got an improvement (~40% faster) by using single-precision complex numbers (np.complex64):

    state = np.transpose((np.linspace(c[0] - s/2, c[0] + s/2, n)[:, None]
                          + 1j*np.linspace(c[1] - s/2, c[1] + s/2, n)))
    state = state.astype(np.complex64)

2. 3Blue1Brown added a link in the video description; you can take a look at how it was done there (side note: the author of that pen used single precision as well).
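
Building on the single-precision idea, here is a minimal sketch of two more things that might be worth trying. It is not from the video or the linked pen, and I have not timed it, so treat it as something to measure rather than a guaranteed win. The first is to evaluate the polynomial and its derivative with plain multiplications (Horner-style) instead of `**`, reusing `z**3` for both. The second is to let broadcasting build the distance array in the root check instead of filling it in a loop. The names `make_grid`, `newton_step` and `closest_root` are made up for this example; the polynomial, grid and root labelling are the ones from the question.

```python
import numpy as np

def make_grid(center=(0.0, 0.0), size=4.0, n=800, dtype=np.complex64):
    """Same n x n grid as in the question, built directly in the requested precision."""
    re = np.linspace(center[0] - size / 2, center[0] + size / 2, n)
    im = np.linspace(center[1] - size / 2, center[1] + size / 2, n)
    # Rows vary in the imaginary part, columns in the real part
    # (equivalent to the transpose used in the question).
    return (re[None, :] + 1j * im[:, None]).astype(dtype)

def newton_step(z):
    """One Newton iteration for f(z) = z**5 + z**2 - z + 1 using only multiplications."""
    z2 = z * z
    z3 = z2 * z
    f = ((z3 + 1) * z - 1) * z + 1   # expands to z**5 + z**2 - z + 1
    df = (5 * z3 + 2) * z - 1        # expands to 5*z**4 + 2*z - 1
    return z - f / df

def closest_root(z, roots):
    """Label each point with the 1-based index of its nearest root."""
    # roots[:, None, None] has shape (k, 1, 1); broadcasting against the
    # (n, n) grid gives a (k, n, n) array of distances without a Python loop.
    return np.argmin(np.abs(roots[:, None, None] - z), axis=0) + 1
```

To use it in place of the loop in `static()`, something like `state = make_grid(n=800)`, then `state = newton_step(state)` repeated `n_steps` times, then `closest_root(state, np.roots(polynomial))` should give the same labelling as `check_distance`. Python integer constants do not upcast a `complex64` array, so the iteration stays in single precision throughout.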