python pandas matplotlib visualization scatter-plot

How to scatter plot each group of a pandas DataFrame


I am making a scatter plot with the geyser dataset from seaborn. I am coloring the points based on the 'kind' column, but for some reason the legend only shows 'long' and leaves out 'short'. I don't know what I am missing. I was also wondering whether there is a simpler way to color-code the data, one that does not use a for-loop.

geyser_df.head()

     duration  waiting   kind
0       3.600       79   long
1       1.800       54  short
2       3.333       74   long
3       2.283       62  short
4       4.533       85   long

x = geyser_df['waiting']
y = geyser_df['duration']
col = []

for i in range(len(geyser_df)):
    if (geyser_df['kind'][i] == 'short'):
        col.append('MediumVioletRed')
    elif(geyser_df['kind'][i] == 'long'):
        col.append('Navy')

plt.scatter(x, y, c=col)
plt.legend(('long','short'))
plt.xlabel('Waiting')
plt.ylabel("Duration")
plt.suptitle("Waiting vs Duration")
plt.show()



Solution

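  • The legend shows only 'long' because plt.scatter is called once, which creates a single artist, so plt.legend(('long', 'short')) has only one artist to label. Plot each group separately with its own label so that every group gets a legend entry.
    import pandas as pd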
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    # load data
    df = sns.load_dataset('geyser')
    
    # plot
    fig, ax = plt.subplots(figsize=(6, 4))
    colors = {'short': 'MediumVioletRed', 'long': 'Navy'}
    for kind, data in df.groupby('kind'):
        data.plot(kind='scatter', x='waiting', y='duration', label=kind, color=colors[kind], ax=ax)
    
    ax.set(xlabel='Waiting', ylabel='Duration')
    fig.suptitle('Waiting vs Duration')
    plt.show()
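
The difference between one scatter call and one call per group can be checked directly. The following is a minimal sketch, assuming a five-row stand-in for the geyser data (the rows shown by geyser_df.head(); the name df_small is illustrative and not part of the original code). It counts the artists created by each approach rather than drawing anything to the screen.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in: the five rows shown by geyser_df.head()
df_small = pd.DataFrame({
    "duration": [3.600, 1.800, 3.333, 2.283, 4.533],
    "waiting": [79, 54, 74, 62, 85],
    "kind": ["long", "short", "long", "short", "long"],
})
colors = {"short": "MediumVioletRed", "long": "Navy"}

# One scatter call creates one artist, however many colors it contains
fig1, ax1 = plt.subplots()
ax1.scatter(df_small["waiting"], df_small["duration"], c=df_small["kind"].map(colors))
single_call_artists = len(ax1.collections)

# One labeled scatter call per group creates one legend entry per group
fig2, ax2 = plt.subplots()
for kind, data in df_small.groupby("kind"):
    ax2.scatter(data["waiting"], data["duration"], color=colors[kind], label=kind)
ax2.legend()
handles, labels = ax2.get_legend_handles_labels()

print(single_call_artists)  # 1
print(labels)               # ['long', 'short']
```

With a single artist, a two-name legend has nothing to attach the second name to, which matches the behavior described in the question.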
    


  • The easiest way is with seaborn, a high-level API for matplotlib, where the hue parameter separates groups by color and the legend is built automatically, with no loop required.
    fig, ax = plt.subplots(figsize=(6, 4))
    colors = {'short': 'MediumVioletRed', 'long': 'Navy'}
    sns.scatterplot(data=df, x='waiting', y='duration', hue='kind', palette=colors, ax=ax)
    
    ax.set(xlabel='Waiting', ylabel='Duration')
    fig.suptitle('Waiting vs Duration')
    plt.show()
    
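  • sns.relplot is the figure-level counterpart of sns.scatterplot: it creates its own figure and returns a FacetGrid, so no plt.subplots call is needed.
    colors = {'short': 'MediumVioletRed', 'long': 'Navy'}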
    p = sns.relplot(data=df, x='waiting', y='duration', hue='kind', palette=colors, height=4, aspect=1.5)
    
    ax = p.axes.flat[0]  # extract the single subplot axes
    
    ax.set(xlabel='Waiting', ylabel='Duration')
    p.figure.suptitle('Waiting vs Duration', y=1.1)
    plt.show()
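
If staying with plain matplotlib is preferred, the per-row for-loop from the question can probably be replaced by a dictionary lookup with Series.map. The sketch below is one possible approach, not the method from the answer above. It assumes the same five-row stand-in for the geyser data (df_small is an illustrative name) and builds the legend by hand from proxy Line2D handles, because a single scatter call produces only one artist.

```python
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

# Hypothetical stand-in for sns.load_dataset('geyser'): the rows from geyser_df.head()
df_small = pd.DataFrame({
    "duration": [3.600, 1.800, 3.333, 2.283, 4.533],
    "waiting": [79, 54, 74, 62, 85],
    "kind": ["long", "short", "long", "short", "long"],
})
colors = {"short": "MediumVioletRed", "long": "Navy"}

fig, ax = plt.subplots(figsize=(6, 4))

# Vectorized lookup: each 'kind' value is mapped to its color, no loop over rows
point_colors = df_small["kind"].map(colors)
ax.scatter(df_small["waiting"], df_small["duration"], c=point_colors)

# Proxy handles: one marker-only Line2D per group, used only for the legend
legend_handles = [
    Line2D([], [], marker="o", linestyle="", color=color, label=kind)
    for kind, color in colors.items()
]
ax.legend(handles=legend_handles, title="kind")

ax.set(xlabel="Waiting", ylabel="Duration")
fig.suptitle("Waiting vs Duration")
```

Calling plt.show() afterwards displays the figure. The trade-off is that the legend is maintained manually, so the groupby and seaborn versions above are likely less error-prone when the set of groups changes.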