Tags: python, matplotlib, seaborn

Scatterplot with point colors representing a continuous variable in seaborn FacetGrid


I am trying to generate a multi-panel figure using seaborn in Python, and I want the color of the points in each panel to be set by a continuous variable. Here's an example of what I am trying to do with the "iris" dataset:

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
iris = sns.load_dataset('iris')

# hue on a continuous column makes every unique petal_length value its own level
g = sns.FacetGrid(iris, col='species', hue='petal_length', palette='seismic')
g = g.map(plt.scatter, 'sepal_length', 'sepal_width', s=100, alpha=0.5)
g.add_legend()

This produces the following figure: [image: iris_continuous]

This works, but the legend is far too long, because it contains one entry per unique petal_length value. Ideally I'd like to show only about a quarter of these values or, failing that, display a colorbar instead. For instance, something like the following might be acceptable, but I'd still want it split over the three species:

plt.scatter(iris.sepal_length, iris.sepal_width, alpha = .8, c = iris.petal_length, cmap = 'seismic')
cbar = plt.colorbar()

[image: one panel]

Any idea how I can get the best of both of these plots?

Edit: This GitHub issue seems like a good starting point:

https://github.com/mwaskom/seaborn/issues/582

For the user in that issue, simply calling plt.colorbar() after everything else had run seemed to work. It doesn't seem to help in this case, though.
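The plt.colorbar idea can be made to work across facets by giving the colorbar an explicit mappable. A minimal sketch, assuming one shared `plt.Normalize` over the full `petal_length` range (so every panel uses the same color scale) and a `ScalarMappable` built from that norm and colormap; `facet_scatter` is a helper defined here for illustration, not a seaborn function. Like the question, it downloads the iris data with `sns.load_dataset`.

```python
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

iris = sns.load_dataset("iris")

# One shared normalisation so every facet maps petal_length to the same colours
norm = plt.Normalize(iris["petal_length"].min(), iris["petal_length"].max())
cmap = plt.cm.seismic


def facet_scatter(x, y, c, **kwargs):
    # FacetGrid.map always passes a single `color`; drop it so `c` drives the colours
    kwargs.pop("color", None)
    plt.scatter(x, y, c=c, norm=norm, cmap=cmap, **kwargs)


g = sns.FacetGrid(iris, col="species")
g.map(facet_scatter, "sepal_length", "sepal_width", "petal_length", s=100, alpha=0.5)

# A ScalarMappable with the same norm and cmap gives the colorbar its scale
sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # required by older matplotlib versions
# Passing all facet axes makes the colorbar take space from the whole row
cbar = g.fig.colorbar(sm, ax=g.axes.ravel().tolist(), label="petal_length")
plt.show()
```

Because the norm is shared, one colorbar is valid for all three panels; a per-panel `plt.colorbar()` would only describe whichever axes happened to be current.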


Solution

  • Since you asked about a legend for the scatter plot, you can adapt @mwaskom's solution to produce a legend built from scatter markers, like so:

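    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    iris = sns.load_dataset('iris')

    # One panel per species; colours come from `c` below, so no hue/palette here
    g = sns.FacetGrid(iris, col='species')

    def facet_scatter(x, y, c, **kwargs):
        # FacetGrid.map always passes a single `color`; drop it so `c` is used
        kwargs.pop("color", None)
        plt.scatter(x, y, c=c, **kwargs)

    # Shared colour scale so a given petal_length has the same colour in every panel
    vmin, vmax = 0, 7
    cmap = plt.cm.viridis
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    g = g.map(facet_scatter, 'sepal_length', 'sepal_width', 'petal_length',
              s=100, alpha=0.5, norm=norm, cmap=cmap)

    # Make space on the right for the legend
    g.fig.subplots_adjust(right=.9)

    # One proxy marker per sampled petal_length value (not drawn on any axes)
    values = np.arange(0, 7.5, 0.5)
    handles = [Line2D([], [], color=cmap(norm(v)), marker="o", linestyle="",
                      markersize=10, alpha=0.5) for v in values]
    g.fig.legend(handles=handles, labels=[f"{v:g}" for v in values],
                 title="petal_length", loc="center right", fontsize=9)

    plt.show()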

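A different technique from the answer above, for seaborn 0.9 or later: `sns.relplot` treats a numeric `hue` as continuous and, with `legend="brief"`, lists only a handful of representative values, which is close to the "sample out about a quarter of the values" behaviour asked for. A minimal sketch, assuming a seaborn version that provides `relplot`; which values appear in the legend is seaborn's own choice, not the hand-picked 0.5 steps used in the answer.

```python
import matplotlib.pyplot as plt
import seaborn as sns

iris = sns.load_dataset("iris")

# relplot maps a numeric hue onto a colormap; legend="brief" shows only a few
# representative petal_length values instead of every unique one
g = sns.relplot(
    data=iris,
    x="sepal_length",
    y="sepal_width",
    hue="petal_length",
    col="species",
    palette="seismic",
    s=100,
    alpha=0.5,
    legend="brief",
)
plt.show()
```

The trade-off is control: the manual proxy-marker legend lets you choose exactly which values are shown, while `relplot` picks them for you.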