How to customize a scatter plot legend with number of observations as labels


I would like to add a box outside the plot area with text giving the number of positive and negative values. The text for each group should have the same color as the corresponding data in the plot: red for positive values and blue for negative ones.

Here is the code I have written:

    import numpy as np
    import matplotlib.pyplot as plt
    
    # time_det, neta, nb_pos_neta and nb_neg_neta are computed earlier in my script
    text_plot = (f"number of positive neta : {nb_pos_neta}\nnumber of negative neta : {nb_neg_neta}")
    
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.scatter(time_det, neta, c=np.sign(neta), cmap="bwr", s=4, label='Rapport of polarisation')
    plt.title('Evolution of rapport of polarisation - Aluminium')
    plt.xlabel('Time [min]')
    plt.ylabel('Rapport [-]')
    plt.figtext(1.05, 0.5, text_plot, ha="right", fontsize=10, bbox={"facecolor": "white", "alpha": 0.5, "pad": 5})
    plt.tight_layout()
    plt.savefig("Evolution of rapport of polarisation - (Aluminium).png")
    plt.show()

And here is the result:

[Image: the resulting scatter plot]
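If colored text (rather than a legend) is what you want, the figtext approach can work by drawing each line with its own `fig.text` call, because a single matplotlib text object has only one color. Note also that `figtext(1.05, 0.5, ..., ha="right")` ends the text at 105% of the figure width, so part of it falls outside the figure. The following is a minimal sketch, not a definitive fix: `plot_with_colored_counts` is a hypothetical helper name, and the positions (`right=0.75`, x = 0.78, y = 0.52/0.48) are guesses you would tune for your data.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_with_colored_counts(time_det, neta):
    """Scatter neta against time and list the sign counts in matching colors."""
    # Count observations by sign; values exactly equal to zero fall in neither group.
    nb_pos_neta = np.count_nonzero(neta > 0)
    nb_neg_neta = np.count_nonzero(neta < 0)

    fig, ax = plt.subplots(figsize=(10, 7))
    # np.sign gives -1/0/+1; when both signs are present, the "bwr"
    # colormap maps -1 to blue and +1 to red.
    ax.scatter(time_det, neta, c=np.sign(neta), cmap="bwr", s=4)
    ax.set_title("Evolution of rapport of polarisation - Aluminium")
    ax.set_xlabel("Time [min]")
    ax.set_ylabel("Rapport [-]")

    # Shrink the axes to leave room on the right of the figure for the text.
    fig.subplots_adjust(right=0.75)
    # One fig.text call per line so each line gets its own color.
    # Coordinates are figure fractions (0 to 1), so the text stays inside the figure.
    fig.text(0.78, 0.52, f"number of positive neta : {nb_pos_neta}",
             color="red", fontsize=10)
    fig.text(0.78, 0.48, f"number of negative neta : {nb_neg_neta}",
             color="blue", fontsize=10)
    return fig, nb_pos_neta, nb_neg_neta
```

With your data this would be called as `fig, n_pos, n_neg = plot_with_colored_counts(time_det, neta)`, followed by `fig.savefig(...)` as before.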


Solution

  • The trick here is to use matplotlib's patches to build a custom legend. I needed some fake data to get a similar-looking plot (and didn't want to dig into neta and time_det, since they aren't central to your question), so I refactored the code to use numpy's where and size for coloring and counting the dots.

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    
    
    # generate some fake data of a similar range
    x = np.random.random(100) * 3000   # stands in for time_det
    y = np.random.random(100)          # stands in for neta
    
    # count the dots on each side of the 0.5 threshold
    count_red = np.size(np.where(y >= .5))
    count_blue = np.size(np.where(y < .5))
    
    # red at or above the threshold, blue below it
    col = np.where(y < .5, 'b', 'r')
    
    fig, ax = plt.subplots(figsize=(10, 7))
    
    # proxy artists: one colored square per group, labelled with its count
    red_patch = mpatches.Patch(color='red', label=str(count_red))
    blue_patch = mpatches.Patch(color='blue', label=str(count_blue))
    
    # legend anchor in axes coordinates (x = 1.0 is the right spine,
    # y is measured up from the bottom spine)
    legend_x = .95
    legend_y = .6
    
    ax.scatter(x, y, c=col, s=5, linewidth=1)
    ax.set_title('Evolution of rapport of polarisation - Aluminium')
    ax.set_xlabel('Time [min]')
    ax.set_ylabel('Rapport [-]')
    ax.legend(bbox_to_anchor=(legend_x, legend_y), loc='upper left',
              handles=[red_patch, blue_patch])
    
    # save only after everything is drawn; bbox_inches='tight' keeps the
    # part of the legend that sticks out past the axes
    plt.tight_layout()
    plt.savefig("Evolution of rapport of polarisation - (Aluminium).png",
                bbox_inches='tight')
    plt.show()
    

    Apart from the y-axis range, this gives an image close to what you described.

    [Image: the resulting scatter plot with the count legend]
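The patch legend shows each count next to a colored square, but the label text itself stays in the default (black) text color. To also color the text, as the question asks, each label can be recolored through `Legend.get_texts()`. This is a sketch under assumptions: `add_count_legend` and its `threshold` parameter are hypothetical names (0.5 matches the fake data; for `neta` it would be 0), and `np.count_nonzero` is swapped in for the where/size counting.

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


def add_count_legend(ax, y, threshold=0.5):
    """Legend whose entries are the per-color counts, with text in the same color."""
    # Hypothetical threshold: values at or above it are red, below it blue.
    count_red = np.count_nonzero(y >= threshold)
    count_blue = np.count_nonzero(y < threshold)
    handles = [
        mpatches.Patch(color="red", label=str(count_red)),
        mpatches.Patch(color="blue", label=str(count_blue)),
    ]
    # loc="upper left" pins the legend's upper-left corner to the anchor;
    # x = 1.02 in axes coordinates is just right of the right spine.
    legend = ax.legend(handles=handles, loc="upper left",
                       bbox_to_anchor=(1.02, 0.6))
    # get_texts() returns the label Text objects in the same order as the handles,
    # so each label takes its patch's face color.
    for text, handle in zip(legend.get_texts(), handles):
        text.set_color(handle.get_facecolor())
    return legend


if __name__ == "__main__":
    # fake data, as in the answer above
    x = np.random.random(100) * 3000
    y = np.random.random(100)
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.scatter(x, y, c=np.where(y < 0.5, "b", "r"), s=5)
    add_count_legend(ax, y)
```

It would replace the legend call after plotting. Passing `bbox_inches='tight'` to `savefig` keeps the part of the legend that lies outside the axes in the saved image.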