Search code examples
python · animation · matplotlib · numerical-methods

Animating bisection method with matplotlib animation library


I am interested in math demonstrations. Currently I am working on visualizing numerical methods in Python, in particular the bisection method. Below is the code I have written so far.

import matplotlib.pyplot as plt 
import matplotlib.animation as animation
import numpy as np 

def sgn(x):
    """Return the sign of x as an int: 1 if positive, -1 if negative, else 0."""
    return 1 if x > 0 else (-1 if x < 0 else 0)

def bisect(f,a,b):
    """Perform one bisection step on the interval [a, b].

    Evaluates f at a, b and the midpoint, then keeps the half-interval
    selected by comparing the sign of f(a) with the sign of f(midpoint).

    Returns a 4-tuple (new_a, f(new_a), new_b, f(new_b)).
    """
    left_val, right_val = f(a), f(b)
    mid = a + (b - a) / 2
    mid_val = f(mid)
    # Different signs at a and the midpoint: keep the left half.
    if sgn(left_val) != sgn(mid_val):
        return a, left_val, mid, mid_val
    # Same sign at a and the midpoint: keep the right half.
    return mid, mid_val, b, right_val

def f(x):
    """Example function x**2 - 3 whose root is being bracketed."""
    return pow(x, 2) - 3

# Starting bracket for the root of f.
a, b = 1, 2

# One bisection step; a and b are rebound to the narrowed bracket.
a, fa, b, fb = bisect(f,a,b)

# Sample f over the narrowed bracket.
vf = np.vectorize(f)
x = np.linspace(a,b)
y = vf(x)

plt.figure()
plt.subplot(111)

plt.plot(x, y, color='blue')
# Dashed red segment from y = 0 to the curve at each end of the bracket.
for endpoint, value in ((a, fa), (b, fb)):
    plt.plot([endpoint, endpoint], [0, value], color='red', linestyle="--")
plt.grid()
plt.show()

I have three problems I wish to solve. First, I want to be able to call the bisect function multiple times, and each time I would like to redraw the plot with the new data. Second, I would like to restart the animation after applying the bisect function some specified number of times. Third, I would like to retain the original axes of the figure from before the bisection method is called, i.e., I would like to keep the x-range as [1, 2] and the y-range as [-2, 1]. Any help will be much appreciated.


Solution

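  • I found a solution to my problems through much trial and error. The line artists are created once with fixed axis limits, each frame recomputes the bracket after i bisection steps and updates the artists with set_data, and FuncAnimation cycles through a fixed number of frames (it repeats by default), which restarts the animation.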

    import matplotlib.pyplot as plt 
    from matplotlib import animation
    import numpy as np 
    
    def sgn(x):
        """Sign of x: 1 for x > 0, -1 for x < 0, and 0 otherwise."""
        if x > 0:
            return 1
        return -1 if x < 0 else 0
    
    def bisect(f,a,b):
        """Perform one bisection step on the bracket [a, b].

        Evaluates f at a and at the midpoint p. If both values have the same
        sign the right half (p, b) is returned, otherwise the left half (a, p).

        Returns a 2-tuple (new_a, new_b).
        """
        fa = f(a)
        p = a+(b-a)/2
        fp = f(p)
        # Only f(a) and f(p) decide which half to keep, so f(b) is never
        # evaluated here.
        if sgn(fa) == sgn(fp):
            return p, b
        else:
            return a, p
    
    def bisection_method(f,a,b,n):
        """Apply bisect n times starting from [a, b] and return the final (a, b)."""
        lo, hi = a, b
        for _ in range(n):
            lo, hi = bisect(f, lo, hi)
        return lo, hi
    
    def f(x):
        """Function whose root is bracketed: x squared minus 3."""
        squared = x ** 2
        return squared - 3
    
    # Bracketing interval; it also fixes the x-range shown for the whole animation.
    xmin, xmax = 1, 2
    vf = np.vectorize(f)
    x = np.linspace(xmin,xmax)
    y = vf(x)
    # Take the y-limits from the sampled curve rather than from f(xmin) and
    # f(xmax) alone, so a function with an extremum inside the interval is
    # not clipped. linspace includes both endpoints, so for a monotone f the
    # limits are the same as the endpoint values.
    yrange = float(y.min()), float(y.max())
    ymin, ymax = yrange
    # Horizontal padding so the bracket lines at xmin and xmax are not drawn
    # on the very edge of the axes.
    epsilon = 0.1
    # Initialize figure with fixed limits; the three line artists start empty
    # and receive their data from the animation callbacks.
    fig = plt.figure()
    ax = plt.axes(xlim=(xmin-epsilon,xmax+epsilon), ylim=(ymin,ymax))
    curve, = ax.plot([],[], color='blue')
    left, = ax.plot([],[],color='red')
    right, = ax.plot([],[],color='red')
    
    # Blank starting frame (passed to FuncAnimation as init_func)
    def init():
        """Empty all three line artists and return them as a tuple."""
        artists = (left, right, curve)
        for artist in artists:
            artist.set_data([],[])
        return artists
    
    # Animation of bisection
    def animate(i):
        """Draw frame i: the bracket reached after i bisection steps from [xmin, xmax].

        The bracket is recomputed from the original interval on every call,
        so the result depends only on i.
        """
        lo, hi = bisection_method(f,xmin,xmax,i)
        # Full-height vertical line at each end of the current bracket.
        for marker, position in ((left, lo), (right, hi)):
            marker.set_data([position, position], [ymin, ymax])
        curve.set_data(x,y)
        return (left, right, curve)
    
    # 15 frames (i = 0..14), 700 ms apart, redrawing only the returned artists.
    animation_options = dict(init_func=init, frames=15, interval=700, blit=True)
    # The animation object is bound to a name so it stays referenced while
    # the figure is shown.
    anim = animation.FuncAnimation(fig, animate, **animation_options)
    ax.grid()
    plt.show()