Search code examples
Tags: python, numpy, jupyter-notebook, matplotlib-animation

Convolution integral export as animation in jupyter


This example is taken from a tutorial and from this post about the convolution integral.

I would like to show it in a Jupyter notebook using matplotlib's animation module. I had a look at this Stack Overflow post. So far, the code looks like this:

import scipy.integrate
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
plt.rcParams.update({
    "animation.html": "jshtml",  # display animation objects inline as a JavaScript/HTML player
    "figure.dpi": 150,
})
plt.ioff()  # interactive mode off; presumably so the bare figure is not auto-displayed in the notebook

def showConvolution(t0, f1, f2):
    """Draw one animation frame of the convolution of f1 and f2 at shift t0.

    Called by ``FuncAnimation`` as ``showConvolution(frame, f1, f2)``.  Reads
    the module-level sample grid ``t`` and draws into the current figure
    through the pyplot state interface: the top subplot shows f1(tau), the
    flipped and shifted f2(t0 - tau) and their product; the bottom subplot
    shows the whole convolution curve and a red dot at its value for t0.

    NOTE: nothing here clears the figure between frames, so each frame's
    curves are drawn on top of the previous frames' -- this is the
    "poor looking animation" described below the code.
    """
    # Calculate the overall convolution result using Simpson integration.
    # (scipy.integrate.simps was removed in SciPy 1.14; simpson replaces it.)
    f1_t = f1(t)  # f1(tau) does not depend on the shift, so evaluate it once
    convolution = np.zeros(len(t))
    for n, t_ in enumerate(t):
        convolution[n] = scipy.integrate.simpson(f1_t * f2(t_ - t), x=t)

    # Create the shifted and flipped function
    f_shift = lambda tau: f2(t0 - tau)
    prod = lambda tau: f1(tau) * f2(t0 - tau)

    # Plot the curves
    plt.subplot(211)
    plt.plot(t, f1(t), label=r'$f_1(\tau)$')
    plt.plot(t, f_shift(t), label=r'$f_2(t_0-\tau)$')
    plt.plot(t, prod(t), 'r-', label=r'$f_1(\tau)f_2(t_0-\tau)$')

    # plot the convolution curve
    plt.subplot(212)
    plt.plot(t, convolution, label='$(f_1*f_2)(t)$')

    # recalculate the value of the convolution integral at the current time-shift t0
    current_value = scipy.integrate.simpson(prod(t), x=t)
    plt.plot(t0, current_value, 'ro')  # plot the point

Fs = 50  # sampling frequency (samples per unit time) of the plotted curves
T = 5    # curves are sampled on the half-open interval [-T, T)
t = np.arange(-T, T, 1 / Fs)  # sample grid; also read by showConvolution


def f1(t):
    """Unit triangle pulse: 1 - |t| for |t| <= 1, zero elsewhere."""
    return np.maximum(0, 1 - abs(t))


def f2(t):
    """One-sided decaying exponential: exp(-2 t) for t > 0, zero otherwise."""
    return (t > 0) * np.exp(-2 * t)


# Shift values t0 swept by the animation, one per frame.
t0 = np.arange(-2.0, 2.0, 0.05)

fig = plt.figure(figsize=(8, 3))
anim = animation.FuncAnimation(
    fig,
    showConvolution,
    frames=t0,
    fargs=(f1, f2),
    interval=80,
)
anim  # last expression of the cell, so the notebook renders it

Now I get a quite poor looking animation instead: [image: screenshot of the animation]


Solution

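  • I adapted the Stack Overflow answer linked in the question into an animation that works in Jupyter, keeping only the code your example needs, and switched to the object-oriented Axes interface because I have no experience writing animations with the pyplot state interface. The problem is that the plot is never cleared between frames: `showConvolution` is called once per frame and adds new lines each time, so every frame is drawn on top of the previous ones. `axes[0].clear()` (and `axes[1].clear()`) removes the previous frame's graph elements before the new ones are drawn.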

    import scipy.integrate
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    
    # Cache for _convolution_curve, keyed on (f1, f2, sample grid).
    _convolution_cache = {}


    def _convolution_curve(f1, f2, t):
        """Return the convolution (f1*f2)(t_) for every t_ in ``t``.

        Each value is the Simpson-rule integral over the grid ``t`` of
        f1(tau) * f2(t_ - tau).  The curve does not depend on the animation
        frame, so it is computed once per (f1, f2, grid) and cached instead of
        being recomputed on every frame.
        """
        t = np.asarray(t)
        key = (f1, f2, t.shape, t.dtype.str, t.tobytes())
        curve = _convolution_cache.get(key)
        if curve is None:
            f1_t = f1(t)  # independent of the shift, so evaluate it once
            # scipy.integrate.simps was removed in SciPy 1.14; simpson replaces it.
            curve = np.array(
                [scipy.integrate.simpson(f1_t * f2(t_ - t), x=t) for t_ in t]
            )
            _convolution_cache[key] = curve
        return curve


    def showConvolution(t0, f1, f2, annotate=False):
        """Draw one animation frame of the convolution of f1 and f2 at shift t0.

        Used as the ``func`` of ``FuncAnimation``: ``t0`` is the frame value and
        ``(f1, f2)`` -- optionally followed by ``annotate`` -- come from
        ``fargs``.  Reads the module-level sample grid ``t`` and the two axes in
        ``axes`` (created with ``fig.subplots(2, 1)``).

        Top axes: f1(tau), the flipped and shifted f2(t0 - tau), and their
        product.  Bottom axes: the whole convolution curve plus a red dot at
        (t0, integral of that product).

        annotate: when true, also draw grids, axis labels, legends, the current
        t0 value and a hatched area under the product.  Off by default, which
        gives the plain plot.
        """
        ax_top, ax_bottom = axes[0], axes[1]
        convolution = _convolution_curve(f1, f2, t)

        # Flip f2 about tau = 0, shift it by t0, and multiply by f1.
        f1_t = f1(t)
        f2_shifted = f2(t0 - t)
        product = f1_t * f2_shifted

        # Remove the previous frame's artists; without this every frame is
        # drawn on top of the earlier ones.
        ax_top.clear()
        ax_bottom.clear()

        # Fixed limits so the view does not rescale from frame to frame.
        ax_top.set_xlim(-5, 5)
        ax_top.set_ylim(0, 1.0)
        ax_top.plot(t, f1_t, label=r'$f_1(\tau)$')
        ax_top.plot(t, f2_shifted, label=r'$f_2(t_0-\tau)$')
        ax_top.plot(t, product, 'r-', label=r'$f_1(\tau)f_2(t_0-\tau)$')

        ax_bottom.set_xlim(-5, 5)
        ax_bottom.set_ylim(0, 0.4)
        ax_bottom.plot(t, convolution, label='$(f_1*f_2)(t)$')

        # Value of the convolution integral at the current shift t0.
        current_value = scipy.integrate.simpson(product, x=t)
        ax_bottom.plot(t0, current_value, 'ro')

        if annotate:
            ax_top.fill(t, product, color='r', alpha=0.5, edgecolor='black', hatch='//')
            ax_top.grid(True)
            ax_top.set_xlabel(r'$\tau$')
            ax_top.set_ylabel(r'$x(\tau)$')
            ax_top.legend(fontsize=10)
            ax_top.text(-4, 0.6, '$t_0=%.2f$' % t0, bbox=dict(fc='white'))

            ax_bottom.grid(True)
            ax_bottom.set_xlabel('$t$')
            ax_bottom.set_ylabel('$(f_1*f_2)(t)$')
            ax_bottom.legend(fontsize=10)
    
    Fs = 50  # our sampling frequency for the plotting
    T = 5    # the time range we are interested in
    t = np.arange(-T, T, 1/Fs)  # the time samples; read by showConvolution
    f1 = lambda t: np.maximum(0, 1-abs(t))  # unit triangle pulse
    f2 = lambda t: (t>0) * np.exp(-2*t)     # one-sided decaying exponential

    t0 = np.arange(-2.0, 2.0, 0.05)  # shift values, one per animation frame

    fig = plt.figure(figsize=(8,3))
    axes = fig.subplots(2, 1)  # read by showConvolution: [top panel, bottom panel]
    anim = animation.FuncAnimation(fig, showConvolution, frames=t0, fargs=(f1, f2), interval=80)

    # To write a video file instead of embedding it:
    #     anim.save('animation.mp4', fps=30)  # fps = frames per second

    from IPython.display import HTML

    # Close the figure; presumably so Jupyter does not also show its last frame
    # as a static image next to the embedded animation.
    plt.close()

    # to_html5_video() encodes with the movie writer named in
    # rcParams['animation.writer'] (ffmpeg by default) and raises RuntimeError
    # when that writer is not installed; fall back to the writer-free
    # JavaScript player in that case.
    if animation.writers.is_available(plt.rcParams['animation.writer']):
        html = anim.to_html5_video()
    else:
        html = anim.to_jshtml()
    HTML(html)
    

    [image: the resulting animation]