
Scatter plot of points from several groups with legend


I need to make a scatter plot of two variables in a data frame. The observations come from three groups, so I want to use three different colors. I also want to include a legend on the plot. Here is a small example.

import pandas as pd
import matplotlib.pyplot as plt

d = {'x': [1, 2, 3, 4, 5, 6, 7], 'y': [3, 5, 7, 2, 4, 1, 8], 'grp': ['a','a','b','a','c','b','c']}
data = pd.DataFrame(d)

I learned that if each group is plotted separately, I can get what I need.

x1 = data.loc[data['grp'] == 'a', 'x']
y1 = data.loc[data['grp'] == 'a', 'y']
x2 = data.loc[data['grp'] == 'b', 'x']
y2 = data.loc[data['grp'] == 'b', 'y']
x3 = data.loc[data['grp'] == 'c', 'x']
y3 = data.loc[data['grp'] == 'c', 'y']
plt.scatter(x1, y1, label = 'a')
plt.scatter(x2, y2, label = 'b')
plt.scatter(x3, y3, label = 'c')
plt.legend()
plt.show()


This approach seems inefficient, especially when the number of groups is large, because every group needs its own pair of selections and its own scatter call. Another approach I found on Stack Overflow is the following.

from matplotlib.colors import ListedColormap
values = data['grp'].map({'a': 0, 'b': 1, 'c': 2})  # group label -> integer code
colors = ListedColormap(['r','b','g'])
scatter = plt.scatter(data['x'], data['y'], c=values, cmap=colors)
plt.legend(handles=scatter.legend_elements()[0], labels=['a', 'b', 'c'])
plt.show()


This is easier since you don't have to handle each group separately. However, I don't quite understand handles=scatter.legend_elements()[0]. Is there an easy and intuitive way of doing this? I was an R user before, and this task is easy in ggplot, where everything seems to be handled automatically. Thanks for the help!
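For reference, legend_elements() on the PathCollection returned by scatter gives back a tuple (handles, labels): one marker handle per distinct value passed as c, plus text labels generated from those numeric values. The [0] simply keeps the handles and discards the numeric labels, which are then replaced by labels=['a', 'b', 'c']. The sketch below is one way to avoid the hand-written mapping, not the canonical approach: it assumes a matplotlib version that has legend_elements (3.1 or later), takes the integer codes and the label order from a pandas categorical, and passes num=None so every distinct code gets its own handle (per the matplotlib docs, the default num="auto" only lists each value individually when there are at most 9 of them). The Agg backend is used only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove to get an interactive window
import matplotlib.pyplot as plt
import pandas as pd

d = {"x": [1, 2, 3, 4, 5, 6, 7],
     "y": [3, 5, 7, 2, 4, 1, 8],
     "grp": ["a", "a", "b", "a", "c", "b", "c"]}
data = pd.DataFrame(d)

# A categorical column carries both the integer codes (0, 1, 2, ...) and the
# group names in the same (sorted) order, so nothing is hard-coded.
grp = data["grp"].astype("category")
codes = grp.cat.codes
group_names = list(grp.cat.categories)

fig, ax = plt.subplots()
scatter = ax.scatter(data["x"], data["y"], c=codes, cmap="viridis")

# legend_elements() returns (handles, labels); the labels describe the numeric
# codes, so only the handles are kept and paired with the real group names.
handles, numeric_labels = scatter.legend_elements(num=None)
ax.legend(handles, group_names, title="grp")
```

Because the handles come back in ascending order of the codes and cat.categories is in the same order, each legend entry lines up with its group.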


Solution

  • You can use the pandas DataFrame.groupby method to split the rows into groups by the "grp" column, then loop over the groups and plot each one individually.

    import pandas as pd
    import matplotlib.pyplot as plt
    
    plt.close("all")
    
    d = {"x": [1, 2, 3, 4, 5, 6, 7], 
         "y": [3, 5, 7, 2, 4, 1, 8], 
         "grp": ["a","a","b","a","c","b","c"]}
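    df = pd.DataFrame(d)
    colors = {"a":"r", "b":"g", "c":"b"}  # one color per group
    
    fig, ax = plt.subplots()
    # One scatter call per group, so each group gets its own legend entry
    for group, data in df.groupby("grp"):
        ax.scatter(data.x, data.y, color=colors[group], label=group)
    ax.legend()
    plt.show()  # blocks until the window is closed when run as a script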
    

    If you want to use the seaborn package, you can use its scatterplot function and pass "grp" as the hue argument to color the points by group.

    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    plt.close("all")
    
    d = {"x": [1, 2, 3, 4, 5, 6, 7], 
         "y": [3, 5, 7, 2, 4, 1, 8], 
         "grp": ["a","a","b","a","c","b","c"]}
    df = pd.DataFrame(d)
    colors = {"a":"r", "b":"g", "c":"b"}
    
    fig, ax = plt.subplots()
    sns.scatterplot(data=df, x="x", y="y", hue="grp", ax=ax, palette=colors)
    plt.show()  # blocks until the window is closed when run as a script
    

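    If you have a large number of groups, you probably don't want to specify the colors manually. In that case, omit the color argument in the matplotlib-only version and the palette argument in the seaborn version: matplotlib then takes each group's color from its default color cycle, and seaborn picks a default palette.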
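A minimal sketch of both options for many groups, assuming matplotlib 3.5 or later (for matplotlib.colormaps). Option 1 omits color and relies on the property cycle; note that matplotlib's default cycle has 10 colors, so it starts repeating after 10 groups. Option 2 is an illustrative alternative, not something from the answer above: it builds the colors dict by sampling a colormap evenly ("viridis" is an arbitrary choice), so every group gets its own color regardless of how many there are. The Agg backend is used only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove to get an interactive window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

d = {"x": [1, 2, 3, 4, 5, 6, 7],
     "y": [3, 5, 7, 2, 4, 1, 8],
     "grp": ["a", "a", "b", "a", "c", "b", "c"]}
df = pd.DataFrame(d)

# Option 1: no color argument. Each ax.scatter call takes the next color
# from the Axes' property cycle.
fig, ax = plt.subplots()
for group, data in df.groupby("grp"):
    ax.scatter(data.x, data.y, label=group)
ax.legend()

# Option 2: map each group to an evenly spaced point of a colormap.
# Calling a colormap with a float in [0, 1] returns an RGBA tuple.
groups = sorted(df["grp"].unique())
cmap = matplotlib.colormaps["viridis"]
colors = {g: cmap(v) for g, v in zip(groups, np.linspace(0, 1, len(groups)))}

fig2, ax2 = plt.subplots()
for group, data in df.groupby("grp"):
    ax2.scatter(data.x, data.y, color=colors[group], label=group)
ax2.legend()
```

Option 2 keeps the same loop as the groupby answer; only the source of the colors dict changes.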