python matplotlib ipywidgets

ipywidget with matplotlib figure always shows two axes


I am trying to create an ipywidget interface with a matplotlib figure that updates when a slider changes. It works in principle, but it always creates an extra figure.

Here's the code:

import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display, clear_output

# make a grid

x = np.linspace(0, 5, 100)
X, Y = np.meshgrid(x, x)

# create the layout with slider

out = widgets.Output(layout=widgets.Layout(height='300px', width = '400px', border='solid'))
slider = widgets.IntSlider(value=1, min=1, max=5)
w = widgets.VBox(children=(out, slider))

# axes to plot into

ax = plt.axes()

display(w)

def update(value):
    i = slider.value
    Z = np.exp(-(X / i)**2 - (Y / i)**2)
    ax.pcolormesh(x, x, Z, vmin=0, vmax=1, shading='auto')
    with out:
        clear_output(wait=True)
        display(ax.figure)

slider.observe(update)
update(None)

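And here's the undesired output:

[Screenshot of the output: the widget with the updating figure on top, and a second, static copy of the figure below it]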

The widget works, and only the upper output is updated, but I do not understand why the lower output also exists or how to get rid of it. Am I missing something obvious?


Solution

  • You can use the widget backend, %matplotlib widget, which I think is designed for this. The backend is provided by the ipympl package, which must be installed, and the %matplotlib widget magic should go at the top of the notebook (or at least before any matplotlib code runs).

    Update: here is some relevant guidance from the matplotlib documentation:

    Note

    To get the interactive functionality described here, you must be using an interactive backend. The default backend in notebooks, the inline backend, is not. backend_inline renders the figure once and inserts a static image into the notebook when the cell is executed. Because the images are static, they can not be panned / zoomed, take user input, or be updated from other cells.

    Your example can be reduced to:

    %matplotlib widget
    import numpy as np
    import matplotlib.pyplot as plt
    from ipywidgets import widgets
    from IPython.display import display
    
    x = np.linspace(0, 5, 100)
    X, Y = np.meshgrid(x, x)
    
    slider = widgets.IntSlider(value=1, min=1, max=5)
    
    # with the widget backend, the figure is rendered as an interactive widget
    ax = plt.axes()
    display(slider)
    
    def update(change):
        i = slider.value
        Z = np.exp(-(X / i)**2 - (Y / i)**2)
        ax.pcolormesh(x, x, Z, vmin=0, vmax=1, shading='auto')
        ax.figure.canvas.draw_idle()  # ask the widget canvas to redraw
    
    # react only to changes of the slider's value, not to every trait event
    slider.observe(update, names='value')
    update(None)
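The reduced example still calls pcolormesh on every slider move, which stacks a new QuadMesh on top of the previous ones. A minimal sketch, assuming matplotlib 3.5 or newer (where QuadMesh.set_array accepts a 2-D array of the mesh's shape), creates the mesh once and only swaps its data in the callback. The helper gaussian is a hypothetical name for the question's formula; a plain dict stands in for the change object that slider.observe(update, names='value') passes to the callback, and the Agg backend is forced only so the sketch also runs outside a notebook. In the notebook, drop the matplotlib.use line and keep %matplotlib widget.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, only so this sketch runs outside a notebook
import matplotlib.pyplot as plt

x = np.linspace(0, 5, 100)
X, Y = np.meshgrid(x, x)


def gaussian(i):
    """Hypothetical helper: the question's surface for slider value i."""
    return np.exp(-(X / i)**2 - (Y / i)**2)


fig, ax = plt.subplots()
# Create the QuadMesh once; later updates only replace its data.
mesh = ax.pcolormesh(x, x, gaussian(1), vmin=0, vmax=1, shading='auto')


def update(change):
    # change['new'] holds the new slider value when observing names='value'
    mesh.set_array(gaussian(change['new']))
    fig.canvas.draw_idle()  # request a redraw of the existing figure


# Stand-in for slider.observe(update, names='value') firing as the slider moves
for v in range(2, 6):
    update({'new': v})
```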
    

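The extra figure in the question comes from the inline backend: plt.axes() registers a new figure with pyplot, and when the cell finishes, the inline backend renders the figures pyplot still has open as static images below the cell. If you would rather stay on the inline backend, one minimal sketch (an alternative to the answer above, assuming a recent matplotlib and ipywidgets 8) is to close the figure right after creating it, so pyplot no longer tracks it, and display it only inside the Output widget. Each update clears the axes first so meshes do not pile up, and names='value' limits the callback to value changes.

```python
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display, clear_output

x = np.linspace(0, 5, 100)
X, Y = np.meshgrid(x, x)

out = widgets.Output(layout=widgets.Layout(height='300px', width='400px', border='solid'))
slider = widgets.IntSlider(value=1, min=1, max=5)

fig, ax = plt.subplots()
# Detach the figure from pyplot so the inline backend does not
# display it again when the cell finishes executing.
plt.close(fig)


def update(change=None):
    i = slider.value
    Z = np.exp(-(X / i)**2 - (Y / i)**2)
    ax.clear()  # drop the previous mesh instead of stacking a new one on top
    ax.pcolormesh(x, x, Z, vmin=0, vmax=1, shading='auto')
    with out:
        clear_output(wait=True)
        display(fig)  # a closed figure can still be rendered explicitly


slider.observe(update, names='value')
display(widgets.VBox(children=(out, slider)))
update()
```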