Tags: python, matplotlib, jupyter-notebook, jupyter-lab

Updating multiple plots in Jupyter notebook when a slider value changes


I want to update multiple imshow plots in a Jupyter notebook when an IntSlider value changes. What is wrong with my code?

These are the versions I am using:

import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

print( 'versions: ipywidgets = ', widgets.__version__)
print( '          matplotlib = ', matplotlib.__version__)
print( '          numpy      = ', np.__version__)

This is the corresponding output:

versions: ipywidgets =  8.0.4
          matplotlib =  3.5.0
          numpy      =  1.20.3

And here is the code:

def plot_image(ax, seed=0):
    np.random.seed(0)
    data2plot = np.random.rand(5,5)
    img = ax.imshow(data2plot)

fig = plt.figure( figsize=(12,6) )
ax1 = fig.add_subplot(1,2,1)
ax2 = fig.add_subplot(1,2,2)

plot_image(ax1)
plot_image(ax2)

plt.show()

slider = widgets.IntSlider(value=0, min=0, max=100, step=1)

# callback function for the slider widget
def update(change):
    plot_image(ax1, seed=0)
    plot_image(ax2, seed=change.new)
    fig.canvas.draw()

# connect update function to slider widget using the .observe() method, observing changes in value attribute
slider.observe(update 'value')
slider

The slider appears (see the screenshot) and I can change its value, but the plots do not change. What am I missing?

Screenshot


Solution

  • Combining Markus's suggestion with fixes for two other issues in your code, this should work:

    # Interactive widget backend (needs the ipympl package) so the figure can be redrawn in place
    %matplotlib ipympl
    import ipywidgets as widgets
    import matplotlib.pyplot as plt
    import matplotlib
    import numpy as np
    
    def plot_image(ax, seed=0):
        np.random.seed(seed)  # use the argument; the question's code always seeded with 0
        data2plot = np.random.rand(5,5)
        img = ax.imshow(data2plot)
    
    fig = plt.figure( figsize=(12,6) )
    ax1 = fig.add_subplot(1,2,1)
    ax2 = fig.add_subplot(1,2,2)
    
    plot_image(ax1)
    plot_image(ax2)
    
    plt.show()
    
    slider = widgets.IntSlider(value=0, min=0, max=100, step=1)
    
    # callback function for the slider widget
    def update(change):
        plot_image(ax1, seed=0)
        plot_image(ax2, seed=change.new)
        fig.canvas.draw()
    
    # connect update function to slider widget using the .observe() method, observing changes in value attribute
    slider.observe(update, 'value')  # a comma was missing between the two arguments
    slider
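Compared with the question, this version calls np.random.seed(seed) instead of np.random.seed(0). The small check below is a sketch that uses two hypothetical helpers, data_buggy and data_fixed, which reproduce only the data-generation part of plot_image. It shows why the old version could never change the right-hand image: every slider value produced exactly the same array.

```python
import numpy as np

def data_buggy(seed=0):
    # Hypothetical helper mirroring the question's plot_image: the seed argument is ignored
    np.random.seed(0)
    return np.random.rand(5, 5)

def data_fixed(seed=0):
    # Hypothetical helper mirroring the answer's plot_image: the argument is actually used
    np.random.seed(seed)
    return np.random.rand(5, 5)

# With the bug, slider values 0 and 42 yield identical data, so the image never changes
print(np.array_equal(data_buggy(0), data_buggy(42)))  # True
# With the fix, different slider values yield different data
print(np.array_equal(data_fixed(0), data_fixed(42)))  # False
```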
    

    Details:
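Each call to plot_image adds another image to the axes, because ax.imshow does not remove the previous one, so after many slider moves the axes hold a stack of images. A common alternative, sketched below under the assumption that the %matplotlib ipympl backend is active as in the answer, creates each image once and then replaces its pixels with AxesImage.set_data. The helper make_data is hypothetical, and the fixed vmin=0, vmax=1 is an added choice so that every frame shares one color scale (np.random.rand returns values in [0, 1)).

```python
import matplotlib.pyplot as plt
import numpy as np

# Assumption: in the notebook, `%matplotlib ipympl` has already been run in this cell or earlier.

def make_data(seed):
    # Hypothetical helper: same 5x5 random data as plot_image, seeded by the slider value
    np.random.seed(seed)
    return np.random.rand(5, 5)

fig = plt.figure(figsize=(12, 6))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2)

# Create each image once; fixed vmin/vmax keeps the color scale identical across updates
img1 = ax1.imshow(make_data(0), vmin=0, vmax=1)
img2 = ax2.imshow(make_data(0), vmin=0, vmax=1)

plt.show()

def update(change):
    # ipywidgets (via traitlets) passes a dict-like object, so change["new"] is the new value
    img2.set_data(make_data(change["new"]))
    # Request a redraw; the existing AxesImage is reused instead of adding a new one
    fig.canvas.draw_idle()
```

The update function is intended as a drop-in replacement for the one in the answer and is connected the same way, with slider.observe(update, 'value'). Only the right-hand image is updated because the left one always uses seed 0. draw_idle() asks the backend to redraw at its convenience, which interactive backends may use to coalesce rapid slider moves; fig.canvas.draw() as in the answer also works.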