python pandas matplotlib seaborn legend

Add legend label for each row in a pandas scatter plot


I have the following table:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

list_1 = [['AU', 152, 474.0],
          ['CA', 440, 482.0],
          ['DE', 250, 564.0],
          ['ES', 707, 549.0],
          ['FR', 1435, 551.0],
          ['GB', 731, 555.0],
          ['IT', 979, 600.0],
          ['NDF', 45041, 357.0],
          ['NL', 247, 542.0],
          ['PT', 83, 462.0],
          ['US', 20095, 513.0],
          ['other', 3655, 526.0]]
labels=['country_destination','num_users','avg_hours_spend']
df=pd.DataFrame(list_1,columns=labels)
df=df.set_index('country_destination')
df
country_destination num_users   avg_hours_spend 
AU                     152        474.0
CA                     440        482.0
DE                     250        564.0
ES                     707        549.0
FR                     1435       551.0
GB                     731        555.0
IT                     979        600.0
NDF                    45041      357.0
NL                     247        542.0
PT                     83         462.0
US                     20095      513.0
other                  3655       526.0

I need to make a scatter plot:

y = df['avg_hours_spend']
x = df['num_users']
N=12
colors = np.random.rand(N)
plt.scatter(x, y,c=colors)

plt.title('Web Sessions Data of Users')
plt.xlabel('No.Of.Users')
plt.ylabel('Mean Hours Users Spends on the Website')
plt.legend()
plt.show()

[Figure: scatter plot where each color is a different country]

Needed: I want to draw larger circles and add a legend on the right side, with a different color for each country. How can I do this?


Solution

  • In matplotlib, you can plot a separate scatter point for each country (i.e. each level of your dataframe's index) and set the s argument to whatever size you want (since you want larger points, I used s=100):

    for i, row in df.iterrows():
        plt.scatter(x=row.num_users, y=row.avg_hours_spend, label=i, s=100)
    
    plt.title("Web Sessions Data of Users")
    plt.xlabel("No.Of.Users")
    plt.ylabel("Mean Hours Users Spends on the Website")
    plt.legend()
    plt.show()
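
The question also asks for the legend on the right side. Here is a minimal sketch of one way to do that, assuming a three-row subset of the table (used only to keep the example short) and a legend title "Country" that is not in the original: anchor the legend's upper-left corner just past the right edge of the axes with bbox_to_anchor.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Three rows taken from the question's table, just to keep the sketch short.
df = pd.DataFrame(
    {"num_users": [152, 440, 250], "avg_hours_spend": [474.0, 482.0, 564.0]},
    index=pd.Index(["AU", "CA", "DE"], name="country_destination"),
)

fig, ax = plt.subplots()

# One scatter call per country, so each gets its own color and legend entry.
for country, row in df.iterrows():
    ax.scatter(row["num_users"], row["avg_hours_spend"], label=country, s=100)

ax.set_title("Web Sessions Data of Users")
ax.set_xlabel("No.Of.Users")
ax.set_ylabel("Mean Hours Users Spends on the Website")

# loc names the corner of the legend box that is pinned; bbox_to_anchor gives
# the pin position in axes coordinates, so x > 1 is to the right of the plot.
legend = ax.legend(title="Country", loc="upper left", bbox_to_anchor=(1.02, 1.0))

# Shrink the axes so the legend is not cut off at the figure edge.
fig.tight_layout()
# Call plt.show() here to display the figure.
```

If the legend is still clipped when saving, passing bbox_inches="tight" to fig.savefig is a common workaround.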
    


    You can achieve a similar result with seaborn, which uses a different syntax and colors the points by the hue column:

    import seaborn as sns
    
    ax = sns.scatterplot(
        x="num_users",
        y="avg_hours_spend",
        hue="country_destination",
        s=100,
        data=df.reset_index(),
    )
    
    ax.set_title("Web Sessions Data of Users")
    ax.set_xlabel("No.Of.Users")
    ax.set_ylabel("Mean Hours Users Spends on the Website")
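
To put the seaborn legend on the right side as well, one option is sns.move_legend, which I believe is available from seaborn 0.11.2 onward. The sketch below assumes a three-row subset of the table and a legend title "Country" that is not in the original.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Three rows taken from the question's table, just to keep the sketch short.
df = pd.DataFrame(
    {"num_users": [152, 440, 250], "avg_hours_spend": [474.0, 482.0, 564.0]},
    index=pd.Index(["AU", "CA", "DE"], name="country_destination"),
)

ax = sns.scatterplot(
    x="num_users",
    y="avg_hours_spend",
    hue="country_destination",
    s=100,
    data=df.reset_index(),
)

# Re-create the legend with its upper-left corner just outside the right
# edge of the axes (coordinates are in axes units).
sns.move_legend(ax, "upper left", bbox_to_anchor=(1.02, 1.0), title="Country")

# Shrink the axes so the legend is not cut off at the figure edge.
ax.figure.tight_layout()
# Call plt.show() here to display the figure.
```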
    
