
One colorbar to indicate data range for multiple subplots using matplotlib?


I have seen many similar questions. However, in those answers the colorbar actually indicates only the data range of the last subplot, as the following code shows:

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(19680801)


fig, axs = plt.subplots(2, 1)
cmaps = ['RdBu_r', 'viridis']
for row in range(2):
    ax = axs[row]
    if row == 0:
        pcm = ax.pcolormesh(np.random.random((20, 20)) * (-100),
                            cmap=cmaps[0])
    elif row == 1:
        pcm = ax.pcolormesh(np.random.random((20, 20)) * 100,
                            cmap=cmaps[0])
fig.colorbar(pcm, ax=axs)
plt.show()

Output

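The colorbar only indicates the data range of the second subplot. The data in the first subplot is actually negative, which is not shown in the colorbar.

If I instead create a separate colorbar for each subplot (inside the loop, with ax=ax), each one shows the range of its own subplot: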

fig, axs = plt.subplots(2, 1)
cmaps = ['RdBu_r', 'viridis']
for row in range(2):
    ax = axs[row]
    if row == 0:
        pcm = ax.pcolormesh(np.random.random((20, 20)) * (-100),
                            cmap=cmaps[0])
    elif row == 1:
        pcm = ax.pcolormesh(np.random.random((20, 20)) * 100,
                            cmap=cmaps[0])
    fig.colorbar(pcm, ax=ax)
plt.show()

Output

So how can I make one colorbar, shared by multiple subplots, indicate the overall data range?

The problem may be caused by fig.colorbar(pcm, ax=axs), where pcm refers to the mesh in the second subplot, but I am not sure how to solve it.
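That diagnosis can be checked directly. The sketch below (same random data as above, with the single RdBu_r colormap) prints each mesh's color limits via get_clim() and the limits of the mappable the colorbar was built from; the names meshes and cbar are illustrative only.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(19680801)

fig, axs = plt.subplots(2, 1)
meshes = []
for ax, mult in zip(axs, (-100, 100)):
    # Without vmin/vmax/norm, each pcolormesh autoscales its own norm to its own data.
    pcm = ax.pcolormesh(np.random.random((20, 20)) * mult, cmap='RdBu_r')
    meshes.append(pcm)

for i, m in enumerate(meshes):
    # get_clim() returns the (vmin, vmax) this mesh uses to map values to colors.
    print(f"subplot {i}: clim = {m.get_clim()}")

# The colorbar is built from the single mappable passed in (the last mesh),
# so it reflects only that mesh's color limits.
cbar = fig.colorbar(meshes[-1], ax=axs)
print("colorbar clim:", cbar.mappable.get_clim())
```

The first mesh's limits are both non-positive and the second's are both non-negative, and the colorbar's limits match only the second mesh.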


Solution

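  • Set the color limits to be the same for every subplot. Otherwise each pcolormesh scales its colors to its own data, and the colorbar only reflects the mappable passed to it (pcm, the last mesh). With identical vmin/vmax, that one colorbar is correct for all subplots: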

    import matplotlib.pyplot as plt 
    import numpy as np 
    
    fig, axs = plt.subplots(2, 1)
    cmaps = ['RdBu_r', 'viridis']
    for row in range(2):
        ax = axs[row]
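        mult = -100 if row == 0 else 100
        # Same vmin/vmax for every mesh, so a color means the same value in both subplots
        pcm = ax.pcolormesh(np.random.random((20, 20)) * mult,
                            cmap=cmaps[0], vmin=-150, vmax=150)
    # pcm is still only the last mesh, but its limits now match the other mesh's
    fig.colorbar(pcm, ax=axs)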
    plt.show()
    

    Or, equivalently, pass one Normalize object shared by all the meshes:

    import matplotlib.pyplot as plt 
    import numpy as np 
    
    
    fig, axs = plt.subplots(2, 1)
    cmaps = ['RdBu_r', 'viridis']
    # One Normalize instance shared by every mesh gives them identical color limits
    norm = plt.Normalize(vmin=-150, vmax=150)
    for row in range(2):
        ax = axs[row]
        mult = -100 if row == 0 else 100
        pcm = ax.pcolormesh(np.random.random((20, 20)) * mult,
                            cmap=cmaps[0], norm=norm)
    fig.colorbar(pcm, ax=axs)
    plt.show()
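The ±150 limits above are hard-coded. If all the arrays can be generated before plotting, a sketch like the following derives the limits from the data instead, using the overall minimum and maximum in one shared Normalize (datasets, vmin and vmax are illustrative names, not part of the original answer).

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize

np.random.seed(19680801)

# Generate all the data first so the overall range is known before plotting.
datasets = [np.random.random((20, 20)) * -100,
            np.random.random((20, 20)) * 100]
vmin = min(d.min() for d in datasets)
vmax = max(d.max() for d in datasets)

# A single Normalize shared by every mesh: identical value-to-color mapping.
norm = Normalize(vmin=vmin, vmax=vmax)

fig, axs = plt.subplots(2, 1)
for ax, data in zip(axs, datasets):
    pcm = ax.pcolormesh(data, cmap='RdBu_r', norm=norm)

# Any of the meshes can feed the colorbar; it now spans the overall data range.
cbar = fig.colorbar(pcm, ax=axs)
```

Call plt.show() afterwards to display the figure. For a diverging colormap such as RdBu_r, symmetric limits (for example ±max(abs(data))) keep zero at the center of the colormap, which is what the ±150 values above achieve.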