
How to customize the handles and labels for a legend


I am trying to customize the Axes of a matplotlib plot. I am using the surpyval package to fit the data and then plot it. The plot method in surpyval does not accept any arguments other than ax, which I pass as ax=ax. My problem is that I can't match the handles and labels in the legend, as you can see from this example:

import numpy as np
import surpyval as surv
import matplotlib.pyplot as plt

y_a = np.array([181, 183, 190, 190, 195, 195, 198, 198, 198, 201, 202, 202, 202,
                204, 205, 205, 206, 206, 206, 206, 207, 209, 213, 214, 218, 219])

y_s = np.array([161, 179, 196, 196, 197, 198, 204, 205, 209, 211, 215, 218, 227,
                230, 231, 232, 232, 236, 237, 237, 240, 243, 244, 246, 252, 255])

# Fit a Weibull model to each data set
model_1 = surv.Weibull.fit(y_a)
model_2 = surv.Weibull.fit(y_s)

ax = plt.gca()

# Draw both fits on the same Axes
model_1.plot(ax=ax)
model_2.plot(ax=ax)

p_a = ['A', 'a_u_CI', 'a_l_CI', 'a_fit']
p_s = ['S', 's_u_CI', 's_l_CI', 's_fit']
p_t = p_a + p_s

# p_t[0:5:4] is ['A', 'S']; only labels are given here, no handles
ax.legend(labels=p_t[0:5:4])



Solution

  • As well as specifying your labels, you also need to tell ax.legend which "handles" to use.

    Normally with matplotlib you could label each line or collection as you plot them, but with surpyval this is not possible.

    Instead, we can find the relevant handles from the matplotlib Axes instance.

    In this case, assuming you want to label the blue and orange circles, you want the two PathCollection instances, which are created using matplotlib's scatter method under the hood. These are easy to access through the ax.collections attribute, so you just need to change the legend line to:

    ax.legend(handles=ax.collections, labels=p_t[0:5:4])
    

    This gives a legend in which the two scatter markers are labelled A and S.
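As a minimal sketch of the same idea without surpyval, the snippet below scatters two unlabelled point sets and then passes ax.collections as the handles. The random data and the Agg backend are assumptions made only so the example runs on its own; they are not part of the original plot.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs anywhere
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
rng = np.random.default_rng(0)

# Stand-in for a library that draws scatter points without labelling them.
for offset in (0.0, 10.0):
    x = np.sort(rng.normal(200 + offset, 10, size=26))
    ax.scatter(x, np.linspace(0.02, 0.98, 26))  # note: no label= argument

# ax.collections holds the PathCollection objects in the order they were drawn,
# so the labels are matched to the handles by position.
legend = ax.legend(handles=list(ax.collections), labels=["A", "S"])
legend_texts = [t.get_text() for t in legend.get_texts()]
print(legend_texts)
```

The labels are paired with the handles in order, so the first label goes with whichever collection was drawn first.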

    For reference, the solid red lines and dashed black lines are stored in ax.lines. For this plot, it holds 6 Line2D objects, corresponding to the 4 red lines and 2 black lines on the plot.
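If you are unsure which entry in ax.lines is which, one option is to list each line's index, colour and line style. The sketch below uses stand-in lines drawn with plain matplotlib; the order (two red bounds, then a dashed black fit, per data set) is an assumption for illustration, so run the same loop on your own Axes to check the real order.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs anywhere
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x = [180, 200, 220]

# Stand-in artists: two solid red bounds and one dashed black fit per data set.
for shift in (0.0, 0.1):
    ax.plot(x, [0.2 + shift, 0.6 + shift, 0.9 + shift], color="r")
    ax.plot(x, [0.1 + shift, 0.4 + shift, 0.7 + shift], color="r")
    ax.plot(x, [0.15 + shift, 0.5 + shift, 0.8 + shift], color="k", linestyle="--")

# Collect (index, colour, line style) for every Line2D on the Axes.
summary = [(i, line.get_color(), line.get_linestyle())
           for i, line in enumerate(ax.lines)]
for row in summary:
    print(row)
```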

    In answer to the comment below ("Do you know how to turn the solid lines into fill_between?"), to fill between two of the lines, you can grab the x- and y-data from the two lines, and then use matplotlib's fill_between function.

    For example, to fill between the two red lines for the "A" data, you could do this:

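    # The two red lines for the "A" data
    l1 = ax.lines[0]
    l2 = ax.lines[1]
    
    # Both lines are assumed to share the same x-data, so only x1 is used below
    x1 = l1.get_xdata()
    x2 = l2.get_xdata()
    
    y1 = l1.get_ydata()
    y2 = l2.get_ydata()
    
    # Negative zorder keeps the shaded band behind the scatter points
    ax.fill_between(x1, y1, y2, zorder=-5, alpha=0.5, color='r')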
    

    Note that you need to set the zorder to something lower than the scatter points, otherwise it'll hide them.

    For the other two red lines in this plot (the "S" data), you would want to use ax.lines[3] and ax.lines[4].
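To avoid repeating those steps for each pair, they could be wrapped in a small helper. This is only a sketch: fill_between_lines is a hypothetical name, the stand-in lines are drawn with plain matplotlib rather than surpyval, and the helper assumes both lines in a pair share the same x-data.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def fill_between_lines(ax, upper, lower, **kwargs):
    """Shade the region between two Line2D objects that share the same x-data."""
    x = np.asarray(upper.get_xdata(), dtype=float)
    x_lower = np.asarray(lower.get_xdata(), dtype=float)
    if x.shape != x_lower.shape or not np.allclose(x, x_lower):
        raise ValueError("the two lines must share the same x-data")
    return ax.fill_between(x, upper.get_ydata(), lower.get_ydata(), **kwargs)


fig, ax = plt.subplots()
x = np.linspace(180, 220, 50)

# Stand-in for the six lines: upper bound, lower bound, fit, for each data set.
for shift in (0.0, 0.2):
    ax.plot(x, np.linspace(0.2, 1.0, 50) + shift, color="r")
    ax.plot(x, np.linspace(0.0, 0.8, 50) + shift, color="r")
    ax.plot(x, np.linspace(0.1, 0.9, 50) + shift, color="k", linestyle="--")

# Fill between lines 0 and 1, then between lines 3 and 4.
bands = [
    fill_between_lines(ax, ax.lines[i], ax.lines[j],
                       zorder=-5, alpha=0.5, color="r")
    for i, j in [(0, 1), (3, 4)]
]
```

Note that fill_between adds its shaded regions to ax.collections as well, so if you build the legend from ax.collections, do that before filling, or pick out the scatter handles explicitly.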


    Note: this was all tested with matplotlib v3.6, surpyval v0.10.10 and Python 3.9.

    I found that using a later release of matplotlib (v3.7.2, Sept 2023) did not work, due to some issues with the way surpyval tries to plot the grid.