I need to make a scatter plot of two variables in a data frame. The observations come from three groups, so I want to use three different colors. I also want to include a legend on the plot. Here is a small example.
import pandas as pd
import matplotlib.pyplot as plt
d = {'x': [1, 2, 3, 4, 5, 6, 7], 'y': [3, 5, 7, 2, 4, 1, 8], 'grp': ['a','a','b','a','c','b','c']}
data = pd.DataFrame(d)
I learned that if each group is plotted separately, I can get what I need.
x1 = data.loc[data['grp'] == 'a', 'x']
y1 = data.loc[data['grp'] == 'a', 'y']
x2 = data.loc[data['grp'] == 'b', 'x']
y2 = data.loc[data['grp'] == 'b', 'y']
x3 = data.loc[data['grp'] == 'c', 'x']
y3 = data.loc[data['grp'] == 'c', 'y']
plt.scatter(x1, y1, label = 'a')
plt.scatter(x2, y2, label = 'b')
plt.scatter(x3, y3, label = 'c')
plt.legend()
plt.show()
This approach seems a little inefficient, especially when the number of groups is large. Another approach I found on Stack Overflow is the following.
from matplotlib.colors import ListedColormap
values = data['grp'].replace(['a','b','c'], [0, 1, 2])
colors = ListedColormap(['r','b','g'])
scatter = plt.scatter(data['x'], data['y'], c=values, cmap=colors)
plt.legend(handles=scatter.legend_elements()[0], labels=['a', 'b', 'c'])
plt.show()
This is easier since you don't have to plot each group separately. However, I don't quite understand handles=scatter.legend_elements()[0]. Is there an easy and intuitive way of doing this? I used to be an R user, and this task is easy in ggplot, where everything seems to be handled automatically. Thanks for the help!
You can use the pandas DataFrame.groupby method to split the data into groups by the "grp" column. The plot can then be made by looping through the groups and plotting each one individually.
import pandas as pd
import matplotlib.pyplot as plt
plt.close("all")
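d = {"x": [1, 2, 3, 4, 5, 6, 7],
     "y": [3, 5, 7, 2, 4, 1, 8],
     "grp": ["a","a","b","a","c","b","c"]}
df = pd.DataFrame(d)
# Map each group label to a matplotlib color.
colors = {"a":"r", "b":"g", "c":"b"}
fig, ax = plt.subplots()
# groupby yields (group label, sub-DataFrame) pairs, one per group.
for group, data in df.groupby("grp"):
    ax.scatter(data.x, data.y, color=colors[group], label=group)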
ax.legend()
fig.show()
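As for the handles=scatter.legend_elements()[0] line in the question: legend_elements() returns a pair, a list of legend handles (one marker per distinct value passed to c) and a list of auto-generated label strings. Indexing with [0] keeps only the handles so that your own labels can be supplied. Here is a minimal sketch using the same example data; the codes mapping and the names handles and auto_labels are just illustrative choices, not anything the original snippet requires.

```python
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

df = pd.DataFrame({
    "x": [1, 2, 3, 4, 5, 6, 7],
    "y": [3, 5, 7, 2, 4, 1, 8],
    "grp": ["a", "a", "b", "a", "c", "b", "c"],
})

# Map each group label to an integer code (illustrative mapping);
# the code selects the entry of the colormap used for that point.
codes = df["grp"].map({"a": 0, "b": 1, "c": 2})

fig, ax = plt.subplots()
scatter = ax.scatter(df["x"], df["y"], c=codes,
                     cmap=ListedColormap(["r", "b", "g"]))

# legend_elements() returns (handles, labels): one handle per distinct
# code, plus label strings generated from the numeric codes.
handles, auto_labels = scatter.legend_elements()

# Keep the handles, but replace the numeric labels with the group names.
ax.legend(handles=handles, labels=["a", "b", "c"], title="grp")
```

The label list has to be in the same order as the sorted codes (0, 1, 2 here), which is one reason the groupby loop is less error-prone.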
If you want to use the seaborn package, you can use the scatterplot function and pass "grp" to the hue argument to color the points by group.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
plt.close("all")
d = {"x": [1, 2, 3, 4, 5, 6, 7],
     "y": [3, 5, 7, 2, 4, 1, 8],
     "grp": ["a","a","b","a","c","b","c"]}
df = pd.DataFrame(d)
colors = {"a":"r", "b":"g", "c":"b"}
fig, ax = plt.subplots()
sns.scatterplot(data=df, x="x", y="y", hue="grp", ax=ax, palette=colors)
fig.show()
If you have a large number of groups, you probably don't want to specify the colors manually. In that case, omit the color argument in the matplotlib-only version or the palette argument in the seaborn version.
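For the matplotlib-only version, that could look roughly like the sketch below (same example data, no colors dictionary). Each ax.scatter call without a color argument takes the next color from matplotlib's default color cycle, so the groups get distinct colors automatically. The legend_labels variable is only there to check the result and is not needed for the plot.

```python
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({
    "x": [1, 2, 3, 4, 5, 6, 7],
    "y": [3, 5, 7, 2, 4, 1, 8],
    "grp": ["a", "a", "b", "a", "c", "b", "c"],
})

fig, ax = plt.subplots()
# No color argument: every scatter call picks the next color
# from the default property cycle.
for group, data in df.groupby("grp"):
    ax.scatter(data["x"], data["y"], label=group)
ax.legend(title="grp")

# Collect the legend labels to confirm there is one entry per group.
legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
```

Note that the default cycle contains ten colors, so with more groups than that the colors start repeating unless you set a different cycle or colormap.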