Tags: python, matplotlib, heatmap, colorbar, scatter

Matplotlib scatterplot with standardized colormap across subplots


I have two sets of data that I would like to compare. Each set has x and y values along with a z value for each (x, y) point. The z ranges of the two sets may overlap, but generally each has a portion the other does not cover; in this example the ranges are [1, 4] and [2, 6]. I would like the color scheme to take this into account so that I can visualize these differences when comparing the two sets. Eventually I would also like to add a colorbar to the figure. Some example code:

import numpy as np
import matplotlib.pyplot as plt

# Fake values: each row is one (x, y, z) point
vals1 = np.array([[1,1,1],[2,2,4]])
vals2 = np.array([[1,1,2],[2,2,6]])

fig, ax = plt.subplots(1,2, constrained_layout=True)
g1 = ax[0].scatter(x=vals1[:,0], y=vals1[:,1], c=vals1[:,2], cmap='RdBu')
g2 = ax[1].scatter(x=vals2[:,0], y=vals2[:,1], c=vals2[:,2], cmap='RdBu')
fig.colorbar(g2)

This gives me the following:

[Figure: scatterplot with bad color scheme]

As you can see, the z (c?) values are not standardized between the subplots. Any help would be greatly appreciated.


Solution

  • You can set the vmin/vmax of both plots to the global min/max of the z values, so that the same value maps to the same color in each subplot.

    Either set the vmin/vmax params individually:

    # Use only the z column (index 2) so x and y values cannot skew the limits
    vmin = np.vstack([vals1,vals2])[:,2].min()
    vmax = np.vstack([vals1,vals2])[:,2].max()
    
    fig, ax = plt.subplots(1,2, constrained_layout=True)
    g1 = ax[0].scatter(x=vals1[:,0], y=vals1[:,1], c=vals1[:,2], vmin=vmin, vmax=vmax, cmap='RdBu')
    g2 = ax[1].scatter(x=vals2[:,0], y=vals2[:,1], c=vals2[:,2], vmin=vmin, vmax=vmax, cmap='RdBu')
    fig.colorbar(g2)
    

    Or create a matplotlib.colors.Normalize() instance and use it for the norm param:

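    import matplotlib.colors as mcolors

    # Use only the z column (index 2) for the shared limits
    norm = mcolors.Normalize(
        vmin=np.vstack([vals1,vals2])[:,2].min(),
        vmax=np.vstack([vals1,vals2])[:,2].max(),
    )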
    
    fig, ax = plt.subplots(1,2, constrained_layout=True)
    g1 = ax[0].scatter(x=vals1[:,0], y=vals1[:,1], c=vals1[:,2], norm=norm, cmap='RdBu')
    g2 = ax[1].scatter(x=vals2[:,0], y=vals2[:,1], c=vals2[:,2], norm=norm, cmap='RdBu')
    fig.colorbar(g2)
    

    [Figure: synced vmin/vmax]
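
For more than two panels, the same idea can be wrapped in a small helper. The following is a minimal sketch rather than part of the original answer: the function name scatter_shared_scale and its return values are hypothetical, and it assumes every dataset is an (n, 3) array whose rows are (x, y, z). It builds one Normalize from the z column of all datasets and passes all axes to fig.colorbar so that a single colorbar serves the whole figure.

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors


def scatter_shared_scale(datasets, cmap="RdBu"):
    """Scatter each (n, 3) array of (x, y, z) rows on its own axes with one shared color scale.

    Hypothetical helper for illustration; not a matplotlib API.
    """
    # Global z range across every dataset (column 2 only, not x or y).
    z_all = np.concatenate([d[:, 2] for d in datasets])
    norm = mcolors.Normalize(vmin=z_all.min(), vmax=z_all.max())

    # squeeze=False keeps a 2-D axes array even when there is a single dataset.
    fig, axes = plt.subplots(1, len(datasets), constrained_layout=True, squeeze=False)
    axes = axes.ravel()

    # Every scatter receives the same Normalize instance.
    artists = [
        ax.scatter(d[:, 0], d[:, 1], c=d[:, 2], norm=norm, cmap=cmap)
        for ax, d in zip(axes, datasets)
    ]

    # Passing all axes via ax= attaches one colorbar to the whole row of panels.
    cbar = fig.colorbar(artists[0], ax=axes)
    return fig, norm, artists, cbar


vals1 = np.array([[1, 1, 1], [2, 2, 4]])
vals2 = np.array([[1, 1, 2], [2, 2, 6]])
fig, norm, artists, cbar = scatter_shared_scale([vals1, vals2])
```

Because all scatter artists share one norm, any of them can be handed to fig.colorbar, and the resulting bar is valid for every panel.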