Search code examples
python matplotlib interactive colorbar matplotlib-3d

Adding colorbar to static axes on 3D interactive matplotlib grid


I am creating an interactive 3D plot using matplotlib and am having some difficulty adding a functional colorbar. I gave the colorbar its own axes so that it would not rotate with the 3D grid. The initial plot and colorbar generate correctly; however, on the first interaction with the slider, the colorbar disappears. Upon the next interaction, the plotted data also disappears. This seems to be due to the cb.remove() call. If cb.remove() is not present, the colorbar simply draws over itself upon each interaction, leaving the plot with many messy colorbars stacked on top of each other. I created the colorbar as its own object that should be regenerated on each slider update, so I can't quite figure out why it is not regenerated, or why removing it would also remove the plotted data.

Here is some code to reproduce the issue.

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
import matplotlib.cm as cm
from matplotlib.colors import Normalize
import numpy as np
%matplotlib notebook

class InteractivePlot:
    """Reproduction case: 3D contour plot with a slider, two buttons and a colorbar.

    The colorbar is drawn in its own axes (``self.cax``), which is created once
    in ``__init__`` and reused by ``update``.  The code is kept exactly as
    posted because it demonstrates the problem described in the question.
    """

    def __init__(self) -> None:
        """Create the figure, the initial contour plot, the colorbar and the widgets."""
        self.fig = plt.figure(layout='constrained')
        self.ax = self.fig.add_subplot(111, projection='3d')
        # initialize plot
        self.cmap = cm.get_cmap('Blues')
        x = np.random.randint(1, 3, size=(4, 4))  # 4x4 array of random integers
        # Colour scale spans the minimum and maximum of the random data.
        normalizer=Normalize(np.min(x),np.max(x))
        im=cm.ScalarMappable(norm=normalizer, cmap=self.cmap)
        
        # Dedicated colorbar axes, stored on the instance and reused in update().
        self.cax_pos = [0.93, 0.15, 0.02, 0.7] # left, bottom, width, height
        self.cax = self.fig.add_axes(self.cax_pos)
        self.cb = self.fig.colorbar(im, cax=self.cax)

        # Coordinate grids on [0, 1] with the same shape as x.
        self.X, self.Y = np.meshgrid(np.linspace(0,1,x.shape[1]), np.linspace(0,1,x.shape[0]))
        self.ax.contourf(x, self.X, self.Y, offset=0, cmap=self.cmap)
        #Create buttons to increment and decrement iteration slider
        self.button1 = Button(plt.axes([0.85, 0.01, 0.05, 0.05]), '-', hovercolor='white')
        self.button2 = Button(plt.axes([0.9, 0.01, 0.05, 0.05]), '+', hovercolor='white')
        
        slider_ax = plt.axes([0.15, 0.01, 0.6, 0.03])
        self.slider = Slider(slider_ax, 'Upper lim', 1, 10, valinit=3, valstep=1)
        # The slider and both buttons are all connected to update().
        self.slider.on_changed(self.update)
        self.button1.on_clicked(self.update)
        self.button2.on_clicked(self.update)
        

    def update(self, val) -> None:
        """Clear the 3D axes, remove the colorbar, and redraw both with new random data.

        ``val`` is supplied by the widget callback and is not used; the slider
        value is read from ``self.slider.val`` instead.
        """
        self.ax.clear()
        # NOTE(review): the question reports that the colorbar vanishes after the
        # first call and the data after the second.  Presumably cb.remove() also
        # detaches self.cax from the figure -- confirm against the
        # Colorbar.remove documentation.
        self.cb.remove()
        
        # These two handlers are defined and registered again on every call to
        # update(), in addition to the update() callbacks connected in __init__.
        def increment_slider(event):
            self.slider.set_val(self.slider.val + self.slider.valstep)
        def decrement_slider(event):
            self.slider.set_val(self.slider.val - self.slider.valstep)
        self.button1.on_clicked(decrement_slider)
        self.button2.on_clicked(increment_slider)

        # New random data whose upper limit is taken from the slider.
        self.x = np.random.randint(1, self.slider.val, size=(4, 4)) 
        normalizer=Normalize(np.min(self.x),np.max(self.x))
        im=cm.ScalarMappable(norm=normalizer, cmap=self.cmap)
        # NOTE(review): `zdirs` may be a typo for `zdir` -- confirm against the
        # Axes3D.contourf signature.
        self.ax.contourf(self.x, self.X, self.Y, offset=0, zdirs='z', cmap=self.cmap)
        # The new colorbar is drawn into the axes stored in __init__.
        self.cb = self.fig.colorbar(im, cax=self.cax)
        self.fig.canvas.draw_idle()
        plt.show()


Solution

  • I have solved the issue but will leave this up in case anyone else has a similar problem. For some reason, storing the colorbar axes on the class (self.cax) and reusing them in the update function causes this issue. Instead, you have to create the axes and set their location inside the update function each time. I'm not sure why reusing self.cax causes this problem (most likely because cb.remove() also removes the axes the colorbar was drawn in, so self.cax is no longer part of the figure), but here is some updated code that works as desired:

    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider, Button
    import matplotlib.cm as cm
    from matplotlib.colors import Normalize
    import numpy as np
    %matplotlib notebook
    
    class InteractivePlot:
        """Interactive 3D filled-contour plot with a colorbar on its own static axes.

        A slider, plus ``-``/``+`` buttons that step it, selects the exclusive
        upper limit of the random integer data.  Every change redraws the
        contours and rebuilds the colorbar in a freshly added axes, so the
        colorbar neither rotates with the 3D axes nor stacks on top of itself.

        Keep a reference to the instance for as long as the figure is open;
        the widgets stop responding once it is garbage collected.  When
        running as a script, call ``plt.show()`` after creating the instance.
        """

        # Colorbar axes rectangle in figure coordinates: left, bottom, width, height.
        CAX_POS = (0.93, 0.15, 0.02, 0.7)
        # Shape of the random data grid as (rows, columns).
        GRID_SHAPE = (4, 4)

        def __init__(self):
            """Build the figure, the widgets and the initial plot."""
            self.fig = plt.figure(layout='constrained')
            self.ax = self.fig.add_subplot(111, projection='3d')
            # plt.get_cmap replaces cm.get_cmap, which was removed in Matplotlib 3.9.
            self.cmap = plt.get_cmap('Blues')
            # The colorbar is created by the first call to update() below.
            self.cb = None

            rows, cols = self.GRID_SHAPE
            self.X, self.Y = np.meshgrid(np.linspace(0, 1, cols), np.linspace(0, 1, rows))

            # Buttons that decrement / increment the slider.
            self.button1 = Button(self.fig.add_axes([0.85, 0.01, 0.05, 0.05]), '-', hovercolor='white')
            self.button2 = Button(self.fig.add_axes([0.9, 0.01, 0.05, 0.05]), '+', hovercolor='white')

            slider_ax = self.fig.add_axes([0.15, 0.01, 0.6, 0.03])
            self.slider = Slider(slider_ax, 'Upper lim', 1, 10, valinit=3, valstep=1)

            # Connect every callback exactly once.  The buttons only move the
            # slider; the resulting slider change is what triggers update(), so
            # each click causes a single redraw.
            self.slider.on_changed(self.update)
            self.button1.on_clicked(self._decrement_slider)
            self.button2.on_clicked(self._increment_slider)

            # Draw the initial data and colorbar through the same code path
            # that later interactions use.
            self.update(self.slider.val)

        def _step_slider(self, direction):
            """Move the slider by ``direction`` steps, clamped to its valid range."""
            target = self.slider.val + direction * self.slider.valstep
            # Slider.set_val does not clamp, so keep the value inside the range here.
            target = min(max(target, self.slider.valmin), self.slider.valmax)
            self.slider.set_val(target)

        def _increment_slider(self, event):
            """Button callback: raise the slider by one step."""
            self._step_slider(+1)

        def _decrement_slider(self, event):
            """Button callback: lower the slider by one step."""
            self._step_slider(-1)

        def update(self, val):
            """Redraw the contours and the colorbar for the current slider value.

            Parameters
            ----------
            val
                Value passed by the widget callback.  It is ignored; the slider
                is read directly so the method behaves the same no matter which
                widget triggered it.
            """
            # randint's upper bound is exclusive and must exceed the lower bound
            # of 1, so a slider value of 1 is treated like 2 instead of raising
            # ValueError.
            upper = max(int(self.slider.val), 2)
            self.x = np.random.randint(1, upper, size=self.GRID_SHAPE)
            norm = Normalize(np.min(self.x), np.max(self.x))

            self.ax.clear()
            # Axes3D.contourf expects (X, Y, Z): the grids are the coordinates
            # and the random data are the heights.  Sharing the norm keeps the
            # contour colours consistent with the colorbar.
            self.ax.contourf(self.X, self.Y, self.x, zdir='z', offset=0,
                             cmap=self.cmap, norm=norm)

            # Colorbar.remove() also removes the axes the colorbar was drawn in,
            # so a new colorbar axes is added on every redraw rather than reused.
            if self.cb is not None:
                self.cb.remove()
            cax = self.fig.add_axes(self.CAX_POS)
            self.cb = self.fig.colorbar(cm.ScalarMappable(norm=norm, cmap=self.cmap), cax=cax)

            # Schedule a repaint; calling plt.show() inside a callback is
            # unnecessary and can re-enter the GUI event loop.
            self.fig.canvas.draw_idle()