python pandas dataframe matplotlib

How to add a legend to df.plot() / legend not showing up with df.plot()


I am currently creating a scatter plot with the results of some evaluation I am doing.

To get a dataframe of the same structure as mine you can run:

import pandas as pd

models = ["60000_25_6", "60000_26_6"]

results = []
for i in range(10):
    for model in models:
        results.append({"simulation": i, "model_id": model, "count_at_1": 1, "count_at_5": 5, "count_at_10": 10})
df = pd.DataFrame(results)

This gives a pandas dataframe with the same structure as the one below, just filled with default values (the one shown is a small example; the real size varies and can be much larger depending on the settings I use):

    simulation    model_id  count_at_1  count_at_5  count_at_10
0            0  60000_25_6          60          77           84
1            0  60000_26_6          60          76           83
2            1  60000_25_6          69          80           82
...
18           9  60000_25_6           1          70           79
19           9  60000_26_6           1          68           74

I then use the following code to add colors to each point:

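import matplotlib.pyplot as plt
import numpy as np

# Sample one evenly spaced color per model from the 'hsv' colormap
colors = plt.get_cmap('hsv')
colors = [colors(i) for i in np.linspace(0, 0.95, len(models))]
cmap = {model: colors[i] for i, model in enumerate(models)}

# Store each row's color as an RGBA tuple
df['color'] = df.apply(lambda row: cmap[row['model_id']], axis=1)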

And df is now:

    simulation    model_id  count_at_1  count_at_5  count_at_10                 color
0            0  60000_25_6          74          81           83  (1.0, 0.0, 0.0, 1.0)
1            0  60000_26_6          75          80           83  (1.0, 0.0, 0.5, 1.0)
2            1  60000_25_6          71          84           89  (1.0, 0.0, 0.0, 1.0)
...
18           9  60000_25_6           2          69           79  (1.0, 0.0, 0.0, 1.0)
19           9  60000_26_6           2          72           78  (1.0, 0.0, 0.5, 1.0)

However when I run:

df.plot.scatter('count_at_1', 'count_at_5', c='color', legend=True)

plt.show()

No legend appears; I just get a normal plot, like so:

[image: plot with no legend]

How can I add a legend where it looks something like:

[model_id]  [color]
...

It should be in the normal matplotlib format; I'll take it anywhere on the plot.


Solution

  • Each set of points needs to be assigned a label for the legend to pick up, so one option is to plot every model separately onto the same axes:

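    fig, ax = plt.subplots()
    # Plot each model as its own scatter series so it gets its own legend entry
    for model in df['model_id'].unique():
        df[df["model_id"].eq(model)].plot.scatter('count_at_1', 'count_at_5',
                                                  c='color', label=model,
                                                  ax=ax)
    # markerfirst=False places each label to the left of its marker
    ax.legend(markerfirst=False)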

    With your sample data:

    [image: plot produced by the code above]
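Another option, offered only as a sketch and not part of the original answer, is to draw all points in a single matplotlib call and build the legend from proxy handles, one per model. It assumes the `models` list and the `cmap` dictionary from the question; the names `handles` and `legend`, the legend title, and the use of `Series.map` in place of `apply` are illustrative choices. The `Agg` backend line is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

# Rebuild the sample dataframe from the question
models = ["60000_25_6", "60000_26_6"]
results = [
    {"simulation": i, "model_id": m, "count_at_1": 1, "count_at_5": 5, "count_at_10": 10}
    for i in range(10)
    for m in models
]
df = pd.DataFrame(results)

# One RGBA color per model, sampled from the 'hsv' colormap as in the question
hsv = plt.get_cmap("hsv")
cmap = {m: hsv(x) for m, x in zip(models, np.linspace(0, 0.95, len(models)))}

fig, ax = plt.subplots()
# Draw every point in one call, colored by its model
ax.scatter(df["count_at_1"], df["count_at_5"], c=df["model_id"].map(cmap).tolist())
ax.set_xlabel("count_at_1")
ax.set_ylabel("count_at_5")

# Proxy artists: empty lines that exist only to supply legend entries
handles = [
    Line2D([], [], marker="o", linestyle="", color=cmap[m], label=m)
    for m in models
]
legend = ax.legend(handles=handles, title="model_id", markerfirst=False)
```

Call `plt.show()` or `fig.savefig(...)` afterwards to view the result. This keeps the plotting to one call when there are many models, at the cost of maintaining the legend entries by hand.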