Tags: python, matplotlib, plot, data-visualization, colorbar

Two colorbars on two subplots, same figure


I am trying to make a matplotlib plot with two subplots, and one colorbar to the right of each subplot. Here is my code currently:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.axes_grid1 import make_axes_locatable

X = tsne_out[:,0]
Y = tsne_out[:,1]
Z = tsne_out[:,2]

fig = plt.figure(figsize = (20,15))
ax1 = fig.add_subplot(221)
ax1.scatter(X, Y, c = material, s = df['Diameter (nm)']/4, cmap = plt.get_cmap('nipy_spectral', 11))
ax1.set_title("2D Representation", fontsize = 18)
ax1.set_xlabel("TSNE1", fontsize = 14)
ax1.set_ylabel("TSNE2", fontsize = 14)
ax1.set_xlim(-20,20)
ax1.set_ylim(-20,20)
ax1.set_xticks(list(range(-20,21,10)))
ax1.set_yticks(list(range(-20,21,10)))


cbar = fig.colorbar(cax, ticks=list(range(0,9)))
cbar.ax.tick_params(labelsize=15) 
cbar.ax.set_yticklabels(custom_ticks)  # y tick labels, i.e. a vertical colorbar


ax2 = fig.add_subplot(222, projection='3d')
ax2.scatter(X, Y, Z, c = material, s = df['Diameter (nm)']/4, cmap = plt.get_cmap('nipy_spectral', 11))
ax2.set_title("3D Representation", fontsize = 18)
ax2.set_xlabel("TSNE1", fontsize = 14)
ax2.set_ylabel("TSNE2", fontsize = 14)
ax2.set_zlabel("TSNE3", fontsize = 14)
ax2.set_xlim(-20,20)
ax2.set_ylim(-20,20)
ax2.set_zlim(-20,20)
ax2.set_xticks(list(range(-20,21,10)))
ax2.set_yticks(list(range(-20,21,10)))
ax2.set_zticks(list(range(-20,21,10)))

cbar = fig.colorbar(cax, ticks = list(range(0,9)))
cbar.ax.tick_params(labelsize=15) 
cbar.ax.set_yticklabels(custom_ticks)

This produces the following figure:

[Figure: Matplotlib produced figure]

My question is: why does the first colorbar not show my custom ticks and how do I fix this?


Solution

  • The issue seems to be that a ScalarMappable can have at most one colorbar associated with it. When you draw a second colorbar from the same ScalarMappable, the first colorbar is unlinked and the settings you applied to it are lost.

    Your code is missing some details (in particular, the definition of cax), so I can only guess what you actually passed to fig.colorbar. Either way, you have to either create two separate mappables or directly use what each scatter call returns. Furthermore, I'd be explicit about which axes each colorbar should be attached to.
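    To make the "one colorbar per mappable" link a bit more concrete, here is a minimal sketch. It assumes that the mappable keeps a single back-reference to its colorbar in its colorbar attribute, which is how the matplotlib releases I know of behave; the random data and the variable names are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Illustrative random data (not the data from the question)
rng = np.random.default_rng(0)
x, y = rng.random(50), rng.random(50)
c = rng.integers(0, 8, 50)

fig, (ax_left, ax_right) = plt.subplots(1, 2)
sc = ax_left.scatter(x, y, c=c, cmap="viridis")
ax_right.scatter(x, y, c=c, cmap="viridis")

# First colorbar: the mappable now refers back to it
first = fig.colorbar(sc, ax=ax_left)
linked_after_first = sc.colorbar is first

# Second colorbar from the SAME mappable: the back-reference is replaced
second = fig.colorbar(sc, ax=ax_right)
linked_after_second = sc.colorbar is second
first_still_linked = sc.colorbar is first

print(linked_after_first, linked_after_second, first_still_linked)
```

    If the assumption holds on your version, the mappable ends up pointing at the second colorbar only, which is consistent with the first one losing its link.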

    An example fix, assuming that cax was really meant to refer to your scatter plots:

    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
    import numpy as np
    
    X = np.random.rand(100) * 40 - 20
    Y = np.random.rand(100) * 40 - 20
    Z = np.random.rand(100) * 40 - 20
    C = np.random.randint(1,8,100)
    custom_ticks = list('ABCDEFGHI')  # one label per tick in range(0,9); newer matplotlib rejects a mismatch
    
    fig = plt.figure(figsize = (20,15))
    ax1 = fig.add_subplot(121)
    sc1 = ax1.scatter(X, Y, c = C, cmap='viridis') # use this mappable
    ax1.set_title("2D Representation", fontsize = 18)
    ax1.set_xlabel("TSNE1", fontsize = 14)
    ax1.set_ylabel("TSNE2", fontsize = 14)
    ax1.set_xlim(-20,20)
    ax1.set_ylim(-20,20)
    ax1.set_xticks(list(range(-20,21,10)))
    ax1.set_yticks(list(range(-20,21,10)))
    
    
    cbar = fig.colorbar(sc1, ax=ax1, ticks=list(range(0,9))) # be explicit about ax1
    cbar.ax.tick_params(labelsize=15) 
    cbar.ax.set_yticklabels(custom_ticks)
    
    ax2 = fig.add_subplot(122, projection='3d')
    sc2 = ax2.scatter(X, Y, Z, c=C, cmap='viridis') # separate mappable for the second colorbar
    ax2.set_title("3D Representation", fontsize = 18)
    ax2.set_xlabel("TSNE1", fontsize = 14)
    ax2.set_ylabel("TSNE2", fontsize = 14)
    ax2.set_zlabel("TSNE3", fontsize = 14)
    ax2.set_xlim(-20,20)
    ax2.set_ylim(-20,20)
    ax2.set_zlim(-20,20)
    ax2.set_xticks(list(range(-20,21,10)))
    ax2.set_yticks(list(range(-20,21,10)))
    ax2.set_zticks(list(range(-20,21,10)))
    
    cbar = fig.colorbar(sc2, ax=ax2, ticks=list(range(0,9))) # use sc2; reusing sc1 here would be the bug
    cbar.ax.tick_params(labelsize=15) 
    cbar.ax.set_yticklabels(custom_ticks)
    
    plt.show()
    

    This produces the following:

    [Figure: created figure, fixed colorbar]

    Note that I created an MCVE for you, and I simplified a few things, for instance the number of subplots. The point is that the colorbar settings stick now that they use separate mappables.


    Another option is to create your colorbars first (using the same ScalarMappable if you want to), then customize both afterwards:

    sc = ax1.scatter(X, Y, c = C, cmap='viridis')
    cbar1 = fig.colorbar(sc, ax=ax1, ticks=np.arange(0,9))
    ax2.scatter(X, Y, Z, c=C, cmap='viridis')
    cbar2 = fig.colorbar(sc, ax=ax2, ticks=np.arange(0,9)) # sc here too
    
    for cbar in (cbar1, cbar2):
        cbar.ax.tick_params(labelsize=15) 
        cbar.ax.set_yticklabels(custom_ticks)
    

    The fact that the above works may suggest that the original behaviour is a bug.
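    If you end up doing this for several subplots, the pattern can be wrapped in a small function. The following is only a sketch under my own assumptions: add_labeled_colorbar and build_demo are hypothetical names, the data is random, and I set vmin/vmax explicitly so that every tick lies inside the colorbar range. It uses one mappable per subplot and checks up front that there is exactly one label per tick.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (needed for '3d' on older matplotlib)
import numpy as np


def add_labeled_colorbar(fig, mappable, ax, ticks, labels, labelsize=15):
    """Attach a colorbar for `mappable` to `ax` and label its ticks (hypothetical helper)."""
    if len(ticks) != len(labels):
        raise ValueError("need exactly one label per tick")
    cbar = fig.colorbar(mappable, ax=ax, ticks=ticks)  # be explicit about the axes
    cbar.set_ticklabels(labels)
    cbar.ax.tick_params(labelsize=labelsize)
    return cbar


def build_demo():
    """Build a 2D and a 3D scatter plot, each with its own labeled colorbar."""
    rng = np.random.default_rng(0)
    X, Y, Z = (rng.random(100) * 40 - 20 for _ in range(3))
    C = rng.integers(0, 8, 100)  # illustrative categories 0..7
    ticks = list(range(8))
    labels = list("ABCDEFGH")

    fig = plt.figure(figsize=(12, 5))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122, projection="3d")
    # One mappable per subplot, both with the same color limits
    sc1 = ax1.scatter(X, Y, c=C, cmap="viridis", vmin=0, vmax=7)
    sc2 = ax2.scatter(X, Y, Z, c=C, cmap="viridis", vmin=0, vmax=7)
    cbar1 = add_labeled_colorbar(fig, sc1, ax1, ticks, labels)
    cbar2 = add_labeled_colorbar(fig, sc2, ax2, ticks, labels)
    return fig, cbar1, cbar2
```

    Calling build_demo() and then fig.savefig(...) or plt.show() (with an interactive backend) should give two colorbars that both keep their custom labels, since neither mappable is reused.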