python-3.x pandas matplotlib label scatter-plot

How to label data points in matplotlib scatter plot while looping through pandas dataframes?


I have a pandas dataframe including the following columns:

label = ('A' , 'D' , 'K', 'L', 'P')
x = (1 , 4 , 9, 6, 4)
y = (2 , 6 , 5, 8, 9)
plot_id = (1 , 1 , 2, 2, 3)

I want to create 3 separate scatter plots, one for each plot_id. So the first scatter plot should consist of all entries where plot_id == 1, and hence the points (1,2) and (4,6). Each data point should be labelled by label, so the first plot should have the labels A and D.

I understand I can use annotate to label, and I am familiar with for loops. But I have no idea how to combine the two.

I wish I could post a better code snippet of what I have done so far, but it's just terrible. Here it is:

for i in range(len(df.plot_id)):
    plt.scatter(df.x[i],df.y[i])
    plt.show()

That's all I've got, unfortunately. Any ideas on how to proceed?


Solution

  • updated answer
    Save a separate image file for each plot_id:

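    def annotate(row, ax):
        # put the row's label 10 points right of and 5 points below its marker
        ax.annotate(row.label, (row.x, row.y),
                    xytext=(10, -5), textcoords='offset points')
    
    # one figure per plot_id; each is saved as <plot_id>.png and then closed
    for pid, grp in df.groupby('plot_id'):
        ax = grp.plot.scatter('x', 'y')
        grp.apply(annotate, ax=ax, axis=1)
        plt.savefig('{}.png'.format(pid))
        plt.close()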
    

    This writes 1.png, 2.png and 3.png, one image file per plot_id.
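For readers who prefer an explicit for loop around annotate, as the question asks, here is a minimal sketch of the same idea using plain matplotlib calls and `DataFrame.itertuples`. The helper name `save_labelled_scatters`, its `prefix` parameter, the subplot titles and the `Agg` backend line are assumptions added for illustration, not part of the original answer.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({
    "label": ["A", "D", "K", "L", "P"],
    "x": [1, 4, 9, 6, 4],
    "y": [2, 6, 5, 8, 9],
    "plot_id": [1, 1, 2, 2, 3],
})


def save_labelled_scatters(df, prefix=""):
    """Hypothetical helper: write one labelled scatter plot per plot_id."""
    saved = []
    for pid, grp in df.groupby("plot_id"):
        fig, ax = plt.subplots()
        ax.scatter(grp["x"], grp["y"])
        # itertuples yields one namedtuple per row, with fields named after the columns
        for row in grp.itertuples(index=False):
            ax.annotate(row.label, (row.x, row.y),
                        xytext=(10, -5), textcoords="offset points")
        ax.set_title("plot_id = {}".format(pid))
        path = "{}{}.png".format(prefix, pid)
        fig.savefig(path)
        plt.close(fig)  # free the figure before moving on to the next group
        saved.append(path)
    return saved


saved_files = save_labelled_scatters(df)
```

The outer loop walks the groups and the inner loop walks the rows of one group, so each point is annotated on the axes it was drawn on.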

    old answer
    For those who want all plots stacked in a single figure, one subplot per plot_id:

    def annotate(row, ax):
        ax.annotate(row.label, (row.x, row.y),
                    xytext=(10, -5), textcoords='offset points')
    
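    # one row of subplots per distinct plot_id, stacked in a single column
    fig, axes = plt.subplots(df.plot_id.nunique(), 1)
    for i, (pid, grp) in enumerate(df.groupby('plot_id')):
        ax = axes[i]
        grp.plot.scatter('x', 'y', ax=ax)
        # label every point of this group on its own axes
        grp.apply(annotate, ax=ax, axis=1)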
    fig.tight_layout()
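One caveat with the subplot version: when there is only a single plot_id, `plt.subplots(1, 1)` returns a single Axes rather than an array, so indexing `axes[i]` would fail. A hedged sketch of one way around this is to pass `squeeze=False`, which always returns a 2-D array of Axes. The titles, the `Agg` backend line and the `labels_per_axes` check are assumptions added for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({
    "label": ["A", "D", "K", "L", "P"],
    "x": [1, 4, 9, 6, 4],
    "y": [2, 6, 5, 8, 9],
    "plot_id": [1, 1, 2, 2, 3],
})

groups = list(df.groupby("plot_id"))

# squeeze=False always yields a 2-D array of Axes, even when there is one group
fig, axes = plt.subplots(len(groups), 1, squeeze=False)

for ax, (pid, grp) in zip(axes[:, 0], groups):
    ax.scatter(grp["x"], grp["y"])
    for row in grp.itertuples(index=False):
        ax.annotate(row.label, (row.x, row.y),
                    xytext=(10, -5), textcoords="offset points")
    ax.set_title("plot_id = {}".format(pid))

fig.tight_layout()

# collect the annotation text on each axes to check the grouping
labels_per_axes = [[t.get_text() for t in ax.texts] for ax in axes[:, 0]]
```

Here `labels_per_axes` groups the labels by subplot, which is a quick way to confirm that each point landed in the plot matching its plot_id.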
    


    setup

    import matplotlib.pyplot as plt
    import pandas as pd
    
    label = ('A' , 'D' , 'K', 'L', 'P')
    x = (1 , 4 , 9, 6, 4)
    y = (2 , 6 , 5, 8, 9)
    plot_id = (1 , 1 , 2, 2, 3)
    
    df = pd.DataFrame(dict(label=label, x=x, y=y, plot_id=plot_id))