Tags: python, matplotlib, jupyter-notebook, ipywidgets

How to make an ipywidgets button update a specific axis of a matplotlib figure?


How can I make an ipywidgets button in a Jupyter notebook update the plot in one specific axis of a figure?

I already know how to make a button update a plot when using a single axis, like so:

import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np

btn = widgets.Button(description='Click')
display(btn)

output = widgets.Output()

def on_click_fn(obj):
    output.clear_output()
    values = np.random.rand(10)
    with output:
        plt.plot(values)
        plt.show()

btn.on_click(on_click_fn)
display(output)

In this example, clicking the button updates the plot and shows a new set of 10 random points. I thought it would be simple to extend this to updating a specific axis, and attempted the following:

import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np

btn = widgets.Button(description='Click')
display(btn)

output = widgets.Output()
fig, ax = plt.subplots(ncols=2)

def on_click_fn(obj):
    output.clear_output()
    values = np.random.rand(10)
    with output:
        ax[0].plot(values)
        plt.show()

btn.on_click(on_click_fn)
display(output)

However, clicking the button in this example does not seem to do anything. I tried different combinations of adding or removing the plt.show() call, using fig.draw() instead, calling fig.canvas.draw_idle(), etc., without much success. What is the correct, least "hacky" way to do this?


Note: This question is specifically about getting a button to update one axis of a multi-axis figure, in the same way that my first example updates a single-axis plot.
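A minimal sketch of one possible approach that keeps the default inline backend. It assumes ipywidgets and the inline backend are in use; `redraw_left_axis` and `make_demo` are hypothetical helper names, not part of any library. A likely reason the second example shows nothing is that the inline backend usually displays and closes figures when the cell that created them finishes. As a result, a later `plt.show()` inside the callback has no open figure to draw. Keeping a reference to `fig` and re-displaying it with `display(fig)` inside the Output widget sidesteps this. Clearing `ax[0]` first also stops lines from accumulating on every click.

```python
import matplotlib.pyplot as plt
import numpy as np


def redraw_left_axis(ax, values):
    """Clear a single Axes and plot new values on it; other Axes are untouched."""
    ax.clear()  # avoid stacking a new line on every click
    ax.plot(values)
    return ax


def make_demo():
    """Hypothetical helper: a button that redraws only ax[0] (inline backend)."""
    # Imported here so redraw_left_axis can also be used outside Jupyter.
    import ipywidgets as widgets
    from IPython.display import display

    fig, ax = plt.subplots(ncols=2)
    ax[1].plot(np.arange(10))  # static content on the right-hand axis
    plt.close(fig)  # keep the inline backend from showing a separate static copy

    btn = widgets.Button(description="Click")
    output = widgets.Output()

    def on_click_fn(_):
        redraw_left_axis(ax[0], np.random.rand(10))
        with output:
            output.clear_output(wait=True)  # replace the previous image
            display(fig)  # render the current state of the whole figure

    btn.on_click(on_click_fn)
    display(btn, output)
    on_click_fn(None)  # draw an initial state
    return fig, ax
```

In a notebook cell, call `make_demo()`. Each click replaces only the left-hand plot, and the right-hand axis keeps its content.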


Solution

  • Using the ipympl backend (%matplotlib widget) with this code:

    # Switch to the interactive ipympl backend before creating any figure
    %matplotlib widget
    import ipywidgets as widgets
    from IPython.display import display
    import matplotlib.pyplot as plt
    import numpy as np
    
    btn = widgets.Button(description='Click')
    display(btn)
    
    output = widgets.Output()
    # Under ipympl the figure is a live, interactive widget
    fig, ax = plt.subplots(ncols=2)
    
    def on_click_fn(obj):
        output.clear_output()
        values = np.random.rand(10)
        with output:
            ax[0].plot(values)  # draw only in the left-hand axis
            plt.show()
    
    btn.on_click(on_click_fn)
    display(output)
    

    I got this output:

    [screenshot of the resulting figure]
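Once the figure is a live ipympl widget, the Output widget and the `plt.show()` call are arguably unnecessary. The following is a sketch of a variant under these assumptions: `%matplotlib widget` has already been run, and matplotlib is 3.4 or newer, which is needed for `plt.ioff()` as a context manager. The idea is to keep a handle to the `Line2D` on `ax[0]`, swap its data with `set_ydata`, and request a redraw with `fig.canvas.draw_idle()`. `update_line` and `make_demo` are hypothetical helper names.

```python
import matplotlib.pyplot as plt
import numpy as np


def update_line(fig, line, values):
    """Replace the y-data of an existing line and request a redraw."""
    line.set_ydata(values)
    line.axes.relim()  # recompute data limits from the new values
    line.axes.autoscale_view()
    fig.canvas.draw_idle()  # under ipympl this refreshes the live figure widget
    return line


def make_demo():
    """Hypothetical helper; assumes `%matplotlib widget` (ipympl) is active."""
    import ipywidgets as widgets
    from IPython.display import display

    with plt.ioff():  # create the figure without auto-displaying it
        fig, ax = plt.subplots(ncols=2)
    (line,) = ax[0].plot(np.random.rand(10))  # the line the button will update
    ax[1].plot(np.arange(10))  # the right-hand axis is left alone

    btn = widgets.Button(description="Click")
    btn.on_click(lambda _: update_line(fig, line, np.random.rand(10)))
    display(widgets.VBox([btn, fig.canvas]))  # ipympl's canvas is itself a widget
    return fig, ax, line
```

Updating the existing line instead of calling `ax[0].plot` again keeps exactly one line in the left-hand axis, however many times the button is clicked.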