Tags: python, matplotlib, limit, subplot, colorbar

Set 'global' colorbar range for multiple matplotlib subplots of different ranges


I would like to plot data in subplots using matplotlib.pyplot in Python. Each subplot contains data with a different range. I would like to plot them using pyplot.scatter and use one single colorbar for the entire figure, so the colorbar should span the full range of values across every subplot. However, when I plot the subplots in a loop and call colorbar after the loop, it only uses the range of values from the last subplot. Most of the available examples concern the sizing and positioning of the colorbar, so the answer to this question (how to make one universal colorbar for multiple subplots) is not obvious.

I have the following self-contained example code. Here, two subplots are rendered: one that should be colored with the frigid temperatures typical of Russia, and the other with the tropical temperatures of Brazil. However, the end result shows a colorbar that only spans the tropical Brazilian temperatures, so the Russia subplot is colored incorrectly:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

core_list = ['Russia', 'Brazil']
core_depth = [0, 2, 4, 6, 8, 10]
lo = [-33, 28]
hi = [10, 38]
df = pd.DataFrame([], columns = ['Location', 'Depth', '%TOC', 'Temperature'])

#Fill df
for ii, name in enumerate(core_list):
    for jj in core_depth:
        df.loc[len(df.index)] = [name, jj, (np.random.randint(1, 20))/10, np.random.randint(lo[ii], hi[ii])]
#Russia data have much colder temperatures than Brazil data due to hi and lo

#Plot data from each location using scatter plots
fig, axs = plt.subplots(nrows = 1, ncols = 2, sharey = True)
for nn, name in enumerate(core_list):
    core_mask = df['Location'] == name
    data = df.loc[core_mask]
    plt.sca(axs[nn])
    plt.scatter(data['Depth'], data['%TOC'], c = data['Temperature'], s = 50, edgecolors = 'k')
    axs[nn].set_xlabel('%TOC')
    plt.text(1.25*min(data['%TOC']), 1.75, name)
    if nn == 0:
        axs[nn].set_ylabel('Depth')

cbar = plt.colorbar()
cbar.ax.set_ylabel('Temperature, degrees C')
#How did Russia get so warm?!? Temperatures and ranges of colorbar are set to last called location. 
#How do I make one colorbar encompass global temperature range of both data sets?

The output of this code shows that the temperatures in Brazil and Russia fall within the same range of colors.

We know intuitively, and from glancing at the data, that this is wrong. So, how do we tell pyplot to plot this correctly?
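One way to see what is going wrong is to inspect the color limits that each scatter call picks for itself. The following is a minimal sketch, not the code from the question: the arrays `russia_temps`, `brazil_temps`, and `depth` are made-up illustrative values, and the `Agg` backend is selected only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove for interactive use
import matplotlib.pyplot as plt
import numpy as np

# Illustrative values only (assumed for this sketch, not the question's random data)
russia_temps = np.array([-30, -12, -5, 0, 4, 9])
brazil_temps = np.array([28, 30, 31, 33, 35, 37])
depth = np.array([0, 2, 4, 6, 8, 10])

fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True)
sc_russia = axs[0].scatter(russia_temps, depth, c=russia_temps)
sc_brazil = axs[1].scatter(brazil_temps, depth, c=brazil_temps)

# With no vmin/vmax given, each scatter scales its colors to its own data,
# so the two collections end up with different color limits.
russia_clim = sc_russia.get_clim()
brazil_clim = sc_brazil.get_clim()
print("Russia color limits:", russia_clim)
print("Brazil color limits:", brazil_clim)
```

Because the two subplots use different color limits, the same color means a different temperature in each one, and a colorbar built from the last scatter call can only describe that last subplot.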


Solution

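  • The answer is straightforward: use the vmin and vmax arguments of pyplot.scatter. Without them, each scatter call scales its colors to its own data, and plt.colorbar() describes only the most recently plotted one. Both arguments must be set from the full range of the data, not just the subset plotted in a single iteration of the loop. Thus, to change the code above: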

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    core_list = ['Russia', 'Brazil']
    core_depth = [0, 2, 4, 6, 8, 10]
    lo = [-33, 28]
    hi = [10, 38]
    
    df = pd.DataFrame([], columns = ['Location', 'Depth', '%TOC', 'Temperature'])
    #Fill df
    for ii, name in enumerate(core_list):
        for jj in core_depth:
            df.loc[len(df.index)] = [
                name,
                jj,
                (np.random.randint(1, 20))/10,
                np.random.randint(lo[ii], hi[ii])
            ]
    #Russia data have much colder temperatures than Brazil data due to hi and lo
    
    #Plot data from each location using scatter plots
    fig, axs = plt.subplots(nrows = 1, ncols = 2, sharey = True)
    for nn, name in enumerate(core_list):
        core_mask = df['Location'] == name
        data = df.loc[core_mask]
        plt.sca(axs[nn])
        plt.scatter(
            data['Depth'],
            data['%TOC'],
            c=data['Temperature'],
            s=50,
            edgecolors='k',
            vmax=max(df['Temperature']),
            vmin=min(df['Temperature'])
        )
        axs[nn].set_xlabel('%TOC')
        plt.text(1.25*min(data['%TOC']), 1.75, name)
        if nn == 0:
            axs[nn].set_ylabel('Depth')
    
    cbar = plt.colorbar()
    cbar.ax.set_ylabel('Temperature, degrees C') 
    
    

    Now the output shows a temperature difference between Russia and Brazil, which is what one would expect after a cursory glance at the data. The change that fixes this problem occurs within the for loop; however, it references the full DataFrame (df, not the per-location data) to find the max and min:

    plt.scatter(data['Depth'], data['%TOC'], c = data['Temperature'], s = 50, edgecolors = 'k', vmax = max(df['Temperature']), vmin = min(df['Temperature']) )
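A variation on the same idea, offered as a sketch rather than as part of the answer above, is to build one shared `matplotlib.colors.Normalize` object from the global range and pass it to every scatter call through `norm`. The colorbar can then be attached to all subplots with `fig.colorbar(..., ax=axs)`. The data below (`temps`, `depth`, `toc`) are made-up illustrative values, and the `Agg` backend is used only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove for interactive use
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize

# Illustrative values only (assumed for this sketch)
temps = {
    "Russia": np.array([-30, -12, -5, 0, 4, 9]),
    "Brazil": np.array([28, 30, 31, 33, 35, 37]),
}
depth = np.array([0, 2, 4, 6, 8, 10])
toc = np.array([0.3, 0.9, 1.2, 1.5, 0.7, 1.8])

# One normalization covering the temperatures of every location
all_temps = np.concatenate(list(temps.values()))
norm = Normalize(vmin=all_temps.min(), vmax=all_temps.max())

fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True)
for ax, (name, location_temps) in zip(axs, temps.items()):
    # Every subplot shares the same norm, so colors are comparable
    sc = ax.scatter(toc, depth, c=location_temps, s=50,
                    edgecolors="k", norm=norm)
    ax.set_title(name)
    ax.set_xlabel("%TOC")
axs[0].set_ylabel("Depth")

# Any of the scatter collections can drive the colorbar, since they share
# one norm; passing ax=axs lets the colorbar take space from both subplots.
cbar = fig.colorbar(sc, ax=axs)
cbar.set_label("Temperature, degrees C")
fig.savefig("shared_colorbar.png")
```

Passing `vmin` and `vmax` directly, as in the answer, and passing a shared `norm` should give the same coloring. The `norm` form mainly keeps the global range in one place when there are many subplots.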
