
Add colorbar as legend to matplotlib scatterplot (multiple subplots, multiple scatters)


I have several subplots to which I want to add a single colorbar. Each subplot consists of 7 scatters. I found advice on how to add colorbars, but it mostly maps colors to the value of each scatter point, not to the row (series) the point belongs to.

Representative sample code:

import numpy as np
from matplotlib import pyplot as plt

x = range(50)
scales = np.linspace(0, 2, 7)
locs = range(4)
cmap = plt.get_cmap("Spectral")
for s_plot in range(4):
    plt.subplot(2, 2, s_plot+1)
    color = iter(cmap(np.linspace(0, 1, len(scales))))
    for scale in scales:
        c = next(color)
        y = np.random.normal(loc=locs[s_plot], scale=scale, size=50)
        plt.scatter(x, y, c=c, s=5)
        plt.title("Mean = {:d}".format(locs[s_plot]))
plt.subplots_adjust(hspace=0.4)
plt.show()

The above example produces the following figure:

[Figure: output of the sample code, a 2×2 grid of subplots titled "Mean = 0" to "Mean = 3"]

My desired colorbar looks like this (a mock-up, to be placed next to the plot):

[Figure: mock-up of the desired colorbar]

So the colorbar should not depict the values of my scatter points, but rather the different "rows" (in this case, the different scales) that are iterated through. In the example, this would help match the points to their scales.

What I tried first was a plain

plt.colorbar()

which I called once after finishing each subplot. But I get TypeError: You must first set_array for mappable. Also, since it is the different scales I want to create the colorbar for, I also tried

plt.colorbar(scales) 

which raises AttributeError: 'numpy.ndarray' object has no attribute 'autoscale_None'.
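To see why both calls fail, it helps to compare a scatter drawn with an explicit color (as in the code above) with one whose c argument holds numbers. The sketch below is illustrative only; the names fixed and mapped are made up for the comparison. Only the second scatter stores a data array that a colorbar can map back to values. Passing the raw scales array fails for a different reason: plt.colorbar expects a ScalarMappable as its first argument, not an ndarray.

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.arange(50)
y = np.random.normal(size=50)
fig, ax = plt.subplots()

# Explicit RGBA color (as in the question): no data array is stored,
# so there is nothing for a colorbar to map back to values.
fixed = ax.scatter(x, y, c=[plt.get_cmap("Spectral")(0.3)], s=5)

# Numeric values in c: scatter maps them through cmap/vmin/vmax itself,
# so the returned collection can serve as a colorbar mappable.
mapped = ax.scatter(x, y, c=np.full(50, 0.3), cmap="Spectral", vmin=0, vmax=2, s=5)

print(fixed.get_array())            # None
print(mapped.get_array() is None)   # False
```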

I am not sure how to proceed from here.

Edit: This question was marked as a possible duplicate of "matplotlib colorbar for scatter". I had already found that question, but it didn't solve my problem. In my case, I need a colorbar that is independent of any z-value and only indicates the "row number" (the scatter series, or however you want to call it), equivalent to the "lines" in a plt.plot.
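If the bar should show one distinct band per series rather than a continuous gradient, one option (a sketch, not taken from the answer below) is to color by series index: resample the colormap to len(scales) entries, use a BoundaryNorm with one bin per index, and relabel the ticks with the scale values. The bin edges at -0.5, 0.5, ... and the two-decimal tick labels are illustrative choices, so that each integer index sits in the middle of its band.

```python
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import BoundaryNorm

x = range(50)
scales = np.linspace(0, 2, 7)
locs = range(4)
n = len(scales)

# Colormap with exactly one entry per series; cmap(i) returns the i-th color.
cmap = plt.get_cmap("Spectral", n)
# Bin edges -0.5, 0.5, ..., n - 0.5: one bin (and one color) per integer index.
norm = BoundaryNorm(np.arange(-0.5, n), ncolors=n)

fig, axes = plt.subplots(2, 2, constrained_layout=True, sharey=True)
for loc, ax in zip(locs, axes.flat):
    for i, scale in enumerate(scales):
        y = np.random.normal(loc=loc, scale=scale, size=50)
        ax.scatter(x, y, c=[cmap(i)], s=5)  # color chosen by series index
    ax.set_title("Mean = {:d}".format(loc))

# Stand-alone mappable sharing the discrete norm/cmap with the scatters.
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
cbar = fig.colorbar(sm, ax=axes[:, 1], ticks=np.arange(n))
cbar.set_ticklabels(["{:.2f}".format(s) for s in scales])  # show scales, not indices
cbar.ax.set_title("scale")
# Display with plt.show() or save with fig.savefig(...).
```

Because the bands are indexed by position rather than by value, this also works when the series labels are not evenly spaced numbers.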


Solution

  • A colorbar needs a ScalarMappable as input. If none of the artists in your plot is suitable for that (here every scatter is drawn with a fixed color rather than from mapped values), you can create one yourself.

    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib.cm import ScalarMappable
    
    x = range(50)
    scales = np.linspace(0, 2, 7)
    locs = range(4)
    cmap = plt.get_cmap("Spectral")
    # Map the scale values (not the y-values) onto the colormap range [0, 1]
    norm = plt.Normalize(scales.min(), scales.max())
    
    fig, axes = plt.subplots(2, 2, constrained_layout=True, sharey=True)
    
    for s_plot, ax in enumerate(axes.flat):
        for scale in scales:
            y = np.random.normal(loc=locs[s_plot], scale=scale, size=50)
            # One fixed color per series, looked up through the same norm/cmap
            ax.scatter(x, y, c=[cmap(norm(scale))], s=5)
        ax.set_title("Mean = {:d}".format(locs[s_plot]))
    
    # Stand-alone mappable sharing norm and cmap with the scatter colors
    sm = ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])  # kept for older Matplotlib releases
    cbar = fig.colorbar(sm, ax=axes[:, 1])  # place beside the right-hand column
    cbar.ax.set_title("scale")
    
    plt.show()
    

    [Figure: output of the solution code]
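An alternative that avoids the separate ScalarMappable (a variation on the answer above, not part of it): give every point of a series that series' scale as its c value and let scatter do the color mapping with one shared norm. Each returned collection is then itself a valid mappable, so the last one can be passed to fig.colorbar. The constant array built with np.full is the only new ingredient; the rest mirrors the question's setup.

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.arange(50)
scales = np.linspace(0, 2, 7)
locs = range(4)
norm = plt.Normalize(scales.min(), scales.max())  # shared by every scatter

fig, axes = plt.subplots(2, 2, constrained_layout=True, sharey=True)
for loc, ax in zip(locs, axes.flat):
    for scale in scales:
        y = np.random.normal(loc=loc, scale=scale, size=len(x))
        # Every point of this series carries the same value (its scale), so the
        # color is still per series, but scatter now stores a mappable array.
        sc = ax.scatter(x, y, c=np.full(len(x), scale), cmap="Spectral", norm=norm, s=5)
    ax.set_title("Mean = {:d}".format(loc))

# All scatters share cmap and norm, so any one of them can drive the colorbar.
cbar = fig.colorbar(sc, ax=axes[:, 1])
cbar.ax.set_title("scale")
# Display with plt.show() or save with fig.savefig(...).
```

The trade-off is that each scatter stores a 50-element value array instead of a single color, which is negligible for plots of this size.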