Tags: python, matplotlib, seaborn, scatter-plot, colormap

Seaborn plot with colorbar, centered around 0


I am trying to change seaborn's default behaviour by adding a colormap (a continuous color palette) as a colorbar instead of relying on the hue argument, whose legend shows a continuous variable as a few binned levels. The following code works, but I would like one more option: center the colorbar at 0, so that 0 gets the color white and the colors diverge from zero towards the negative and positive values.

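import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

y = np.random.normal(30, 30, 100)
x = np.random.uniform(0, 50, 100)

# hue maps y onto RdBu; sizes=(50, 50) keeps every marker the same size
s = sns.scatterplot(
    y=y,
    x=x,
    hue=y,
    size=y,
    palette='RdBu',
    sizes=(50, 50)
)
# Colorbar using the same linear normalisation over the data range
norm = plt.Normalize(y.min(), y.max())
sm = plt.cm.ScalarMappable(cmap="RdBu", norm=norm)
sm.set_array([])
s.get_legend().remove()
# ax=s tells colorbar which Axes to take space from
s.figure.colorbar(sm, ax=s)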

[Image: the resulting scatter plot with its colorbar]

As can be seen from the image, 0 gets a slightly reddish color because the data are not symmetric about zero: plt.Normalize(y.min(), y.max()) places the white midpoint of RdBu at (y.min() + y.max()) / 2, which is well above 0 here. How can I center the colormap at 0? I am fine with an inflated colormap, say from -80 to 80 (because of the asymmetry), as long as the center is at 0.
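The offset can be checked numerically. The sketch below assumes a hypothetical data range of -50 to 110 (roughly what np.random.normal(30, 30, 100) produces) and compares where 0 lands on the 0 to 1 colormap scale under a plain linear Normalize, under matplotlib's TwoSlopeNorm, and under symmetric limits. In RdBu, values below 0.5 fall in the red half.

```python
import matplotlib.colors as mcolors

# Hypothetical data range, roughly what np.random.normal(30, 30, 100) yields
vmin, vmax = -50.0, 110.0

# Linear normalisation as in the question: 0.5 (the light centre of RdBu)
# sits at (vmin + vmax) / 2 = 30, so 0 falls in the red half
linear = mcolors.Normalize(vmin=vmin, vmax=vmax)

# TwoSlopeNorm pins vcenter=0 to 0.5 and scales each side separately
two_slope = mcolors.TwoSlopeNorm(vcenter=0.0, vmin=vmin, vmax=vmax)

# Symmetric, "inflated" limits (-110..110 here) also put 0 at 0.5
vlim = max(abs(vmin), abs(vmax))
symmetric = mcolors.Normalize(vmin=-vlim, vmax=vlim)

for name, norm in [("linear", linear), ("two_slope", two_slope), ("symmetric", symmetric)]:
    print(name, float(norm(0.0)))
# linear 0.3125
# two_slope 0.5
# symmetric 0.5
```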


Solution

  • Pass the c, norm and cmap keyword arguments, which seaborn forwards to matplotlib.axes.Axes.scatter, to colour the points (instead of using palette), and build a normalisation centred on zero with mcolors.TwoSlopeNorm. The plot can then be generated like so:

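    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import matplotlib.colors as mcolors

    fig, ax = plt.subplots()

    y = np.random.normal(30, 30, 100)
    x = np.random.uniform(0, 50, 100)

    # TwoSlopeNorm maps vmin -> 0, vcenter -> 0.5 and vmax -> 1, scaling each
    # side separately, so 0 always gets the white centre of RdBu.
    # It requires vmin < vcenter < vmax.
    vcenter = 0
    vmin, vmax = y.min(), y.max()
    normalize = mcolors.TwoSlopeNorm(vcenter=vcenter, vmin=vmin, vmax=vmax)
    colormap = cm.RdBu

    # c, norm and cmap are passed through seaborn to Axes.scatter
    sns.scatterplot(
        y=y,
        x=x,
        c=y,
        norm=normalize,
        cmap=colormap,
        ax=ax,
    )

    scalarmappable = cm.ScalarMappable(norm=normalize, cmap=colormap)
    scalarmappable.set_array(y)
    # ax=ax tells colorbar which Axes to take space from; recent Matplotlib
    # versions warn or raise without it for a standalone ScalarMappable
    fig.colorbar(scalarmappable, ax=ax)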
    

    [Image: the resulting plot]
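Since the question also accepts an inflated, symmetric range such as -80 to 80, a plain Normalize with limits of plus and minus max|y| is another way to put 0 at the centre. This is a sketch of that alternative, not part of the original answer. It relies on the same c/norm/cmap pass-through described above, and it passes the PathCollection drawn by scatter (named `points` here purely for illustration) to colorbar instead of building a separate ScalarMappable. The seeded generator and the Agg backend are assumptions added only to make it reproducible and headless.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in an interactive session
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns

rng = np.random.default_rng(0)  # seeded so the sketch is reproducible
y = rng.normal(30, 30, 100)
x = rng.uniform(0, 50, 100)

# Symmetric limits around 0: the largest |y| sets both ends,
# so 0 maps to 0.5, the light centre of RdBu
vlim = np.abs(y).max()
norm = mcolors.Normalize(vmin=-vlim, vmax=vlim)

fig, ax = plt.subplots()
# c, norm and cmap are forwarded by seaborn to Axes.scatter
sns.scatterplot(x=x, y=y, c=y, norm=norm, cmap="RdBu", ax=ax)

# The scatter's PathCollection is itself a ScalarMappable holding the data,
# norm and cmap, so it can feed the colorbar directly
points = ax.collections[-1]
fig.colorbar(points, ax=ax)
```

The trade-off: symmetric limits keep equal distances from 0 at equal colour intensity, but the shorter side of the data (here the negatives) only uses part of its half of the colormap. TwoSlopeNorm stretches both sides to full intensity at the cost of different scales above and below 0.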