python matplotlib plot seaborn visualization

Manually set values shown in legend for continuous variable of seaborn/matplotlib scatterplot


Is there a way to manually set the values shown in the legend of a seaborn (or matplotlib) scatterplot when the legend contains a continuous variable (hue)?

For example, in the plot below I would like the legend to show the colors corresponding to the values [0, 1, 2, 3] rather than [1.5, 3, 4.5, 6, 7.5]:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

np.random.seed(123)
x = np.random.randn(500)
y = np.random.randn(500)
z = np.random.exponential(1, 500)

fig, ax = plt.subplots()
hue_norm = (0, 3)
sns.scatterplot(
    x=x,
    y=y,
    hue=z,
    hue_norm=hue_norm,
    palette='coolwarm',
    ax=ax,
)

ax.grid()
ax.set(xlabel="x", ylabel="y")
ax.legend(title="z")
sns.despine()



Solution

  • Seaborn builds its scatterplot a bit differently from plain matplotlib, which allows more kinds of customization. For the legend, Seaborn 0.13 uses custom Line2D elements (older Seaborn versions use PathCollections).

    The following approach:

    • replaces Seaborn's hue_norm=(0, 3) with an equivalent matplotlib norm
    • creates dummy Line2D elements to serve as legend handles
    • copies all properties (size, edgecolor, ...) of the legend handle created by Seaborn
    • then changes the marker color depending on the norm and colormap

    The approach might need some tweaks if your scatterplot differs. The code has been tested with Matplotlib 3.8.3 and Seaborn 0.13.2 (and 0.12.2).
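
    The color lookup behind the last step can be tried on its own before running the full listing. This is only a minimal sketch using matplotlib alone; the names `fractions` and `colors` are illustrative and do not appear in the original answer. The norm maps each legend value linearly into [0, 1], and the colormap turns that fraction into an RGBA tuple.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

# Same range as Seaborn's hue_norm=(0, 3), expressed as a matplotlib norm
norm = plt.Normalize(vmin=0, vmax=3)
cmap = plt.get_cmap('coolwarm')

legend_keys = [0, 1, 2, 3]

# Step 1: the norm rescales each key linearly from [vmin, vmax] to [0, 1]
fractions = [float(norm(key)) for key in legend_keys]

# Step 2: the colormap converts each fraction to an (r, g, b, a) tuple
colors = [cmap(frac) for frac in fractions]

for key, frac, rgba in zip(legend_keys, fractions, colors):
    print(key, round(frac, 3), to_hex(rgba))
```

    With this norm, the key 0 lands on the first color of the colormap and the key 3 on the last one, so the legend entries span the same color range that the scatterplot uses.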

    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    from matplotlib.lines import Line2D
    
    np.random.seed(123)
    x = np.random.randn(500)
    y = np.random.randn(500)
    z = np.random.exponential(1, 500)
    
    fig, ax = plt.subplots()
    # matplotlib norm equivalent to Seaborn's hue_norm=(0, 3)
    hue_norm = plt.Normalize(vmin=0, vmax=3)
    sns.scatterplot(x=x, y=y, hue=z, hue_norm=hue_norm, palette='coolwarm', ax=ax)
    
    legend_keys = [0, 1, 2, 3]  # values to show in the legend
    cmap = plt.get_cmap('coolwarm')
    # first legend handle created by Seaborn, used as a style template
    template = ax.legend_.legend_handles[0]
    handles = []
    for key in legend_keys:
        h = Line2D([], [])  # dummy element, only used as a legend handle
        if isinstance(template, Line2D):
            # Seaborn 0.13: copy all properties (size, edgecolor, ...)
            h.update_from(template)
        else:
            # older Seaborn: the handle is a PathCollection, so copy the style by hand
            h.set_linestyle('')
            h.set_marker('o')
            h.set_markeredgecolor(template.get_edgecolor())
            h.set_markeredgewidth(template.get_linewidth())
        # color the marker according to the norm and colormap
        h.set_markerfacecolor(cmap(hue_norm(key)))
        h.set_label(f'{key}')
        handles.append(h)
    ax.legend(handles=handles, title='z')
    sns.despine()
    plt.show()
    

    [Figure: seaborn scatterplot with custom legend]