Tags: python, matplotlib, scatter-plot

Annotate a scatterplot with text and position taken from a pandas dataframe directly


What I want to achieve is a more elegant and direct way of annotating scatter points, whose x and y positions come from a pandas DataFrame, with a label taken from the same row.

The working example below produces what I want, but I feel there must be a more elegant solution that does not require first storing individual columns in separate lists and then iterating over them.

My concern is that keeping these separate lists could result in labels becoming misaligned with their data points in larger, more complicated datasets with missing values, NaNs, etc.

In this example, x = Temperature, y = Sales and the label is the Date.

import pandas as pd
import matplotlib.pyplot as plt

d = {'Date': ['15-08-24', '16-08-24', '17-08-24'], 'Temperature': [24, 26, 20], 'Sales': [100, 150, 90]}
df = pd.DataFrame(data=d)

Which gives:

       Date  Temperature  Sales
0  15-08-24           24    100
1  16-08-24           26    150
2  17-08-24           20     90

Then:

temperature_list = df['Temperature'].tolist()
sales_list = df['Sales'].tolist()
labels_list = df['Date'].tolist()

fig, axs = plt.subplots()
axs.scatter(data=df, x='Temperature', y='Sales')
for i, label in enumerate(labels_list):
    axs.annotate(label, (temperature_list[i], sales_list[i]))
plt.show()
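A minimal sketch of the same plot without the intermediate lists, assuming the sample `df` from above: `zip` walks the three DataFrame columns in lockstep, so each label is paired with the coordinates from its own row instead of being looked up by position in separate lists.

```python
import pandas as pd
import matplotlib.pyplot as plt

d = {'Date': ['15-08-24', '16-08-24', '17-08-24'],
     'Temperature': [24, 26, 20],
     'Sales': [100, 150, 90]}
df = pd.DataFrame(data=d)

fig, axs = plt.subplots()
axs.scatter(data=df, x='Temperature', y='Sales')

# zip iterates the columns side by side, so label, x and y
# always come from the same row of df
for label, x, y in zip(df['Date'], df['Temperature'], df['Sales']):
    axs.annotate(label, (x, y))

plt.show()
```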

What I am aiming for, which does not work, is something along the lines of:

fig, axs = plt.subplots()
axs.scatter(data=df, x='Temperature', y='Sales')
axs.annotate(data=df, x='Temperature', y='Sales', text='Date') # this is invalid
plt.show()
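As far as I know, `Axes.annotate` has no `data=`/`x=`/`y=`/`text=` form like the one `Axes.scatter` accepts, so the call above cannot work as written. One way to get close to that call shape is a small wrapper. `annotate_df` below is a hypothetical helper written for illustration, not part of matplotlib or pandas; any extra keyword arguments are passed straight through to `Axes.annotate`.

```python
import pandas as pd
import matplotlib.pyplot as plt


def annotate_df(ax, data, x, y, text, **kwargs):
    """Hypothetical helper: add one annotation per row of `data` to `ax`.

    `x`, `y` and `text` are column names, mirroring the data=/x=/y= style
    of Axes.scatter. Returns the list of created Annotation objects.
    """
    annotations = []
    # Columns are read side by side, so each label keeps its own row's x/y
    for label, xv, yv in zip(data[text], data[x], data[y]):
        annotations.append(ax.annotate(str(label), (xv, yv), **kwargs))
    return annotations


d = {'Date': ['15-08-24', '16-08-24', '17-08-24'],
     'Temperature': [24, 26, 20],
     'Sales': [100, 150, 90]}
df = pd.DataFrame(data=d)

fig, axs = plt.subplots()
axs.scatter(data=df, x='Temperature', y='Sales')
# Offset the labels slightly so they do not sit on top of the markers
anns = annotate_df(axs, data=df, x='Temperature', y='Sales', text='Date',
                   xytext=(3, 3), textcoords='offset points')
plt.show()
```

The loop is still there, but it lives in one place and the call site reads like the `scatter` call.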

Suggestions are welcome. If there is no way around the iterative process, perhaps there is at least a fail-safe method to guarantee that labels are attributed to the correct data points.
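One fail-safe pattern, sketched here under the assumption that rows with a missing coordinate should simply be skipped: filter the DataFrame once with `dropna`, then take both the points and the labels from that same filtered frame. The extra row and the NaN below are made-up data for illustration.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Made-up data: the '17-08-24' row has a missing Sales value
d = {'Date': ['15-08-24', '16-08-24', '17-08-24', '18-08-24'],
     'Temperature': [24, 26, 20, 22],
     'Sales': [100, 150, np.nan, 120]}
df = pd.DataFrame(data=d)

# Drop incomplete rows once; points and labels are both drawn from this
# filtered frame, so no column can shift relative to another
plotted = df.dropna(subset=['Temperature', 'Sales'])

fig, axs = plt.subplots()
axs.scatter(data=plotted, x='Temperature', y='Sales')
for label, x, y in zip(plotted['Date'], plotted['Temperature'], plotted['Sales']):
    axs.annotate(label, (x, y))

plt.show()
```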


Solution

  • You probably can't avoid the iteration, since Axes.annotate adds a single annotation per call, but you can remove the need to create lists by using df.iterrows(). This has the added benefit that you never decouple any data from your DataFrame.

    import pandas as pd
    import matplotlib.pyplot as plt
    
    d = {'Date': ['15-08-24', '16-08-24', '17-08-24'], 'Temperature': [24, 26, 20], 'Sales': [100, 150, 90]}
    df = pd.DataFrame(data=d)
    
    fig, axs = plt.subplots()
    axs.scatter(data=df, x='Temperature', y='Sales')
    
    for _, row in df.iterrows():
        axs.annotate(row["Date"], (row["Temperature"], row["Sales"]))

    plt.show()
    

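A variant of the same loop uses `df.itertuples()` instead. `iterrows()` builds a pandas Series for every row, which is comparatively slow and does not preserve dtypes across a row, whereas `itertuples()` yields lightweight namedtuples and is generally faster on larger frames. A minimal sketch with the same sample data; note that attribute access such as `row.Date` only works when column names are valid Python identifiers, otherwise pandas renames those fields positionally.

```python
import pandas as pd
import matplotlib.pyplot as plt

d = {'Date': ['15-08-24', '16-08-24', '17-08-24'],
     'Temperature': [24, 26, 20],
     'Sales': [100, 150, 90]}
df = pd.DataFrame(data=d)

fig, axs = plt.subplots()
axs.scatter(data=df, x='Temperature', y='Sales')

# One namedtuple per row; fields are looked up by column name,
# so the label and its coordinates always belong to the same row
for row in df.itertuples(index=False):
    axs.annotate(row.Date, (row.Temperature, row.Sales))

plt.show()
```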