I am trying to customize the Axes of a matplotlib plot. I am using the surpyval package to fit the data and then plot it. The plot method in surpyval does not accept any arguments other than ax=ax, as shown below. My problem is that I can't match the handles and legend labels, as you can see from this example:
import numpy as np  # needed for np.array below
import surpyval as surv
import matplotlib.pyplot as plt

y_a = np.array([181, 183, 190, 190, 195, 195, 198, 198, 198, 201, 202, 202, 202,
                204, 205, 205, 206, 206, 206, 206, 207, 209, 213, 214, 218, 219])
y_s = np.array([161, 179, 196, 196, 197, 198, 204, 205, 209, 211, 215, 218, 227,
                230, 231, 232, 232, 236, 237, 237, 240, 243, 244, 246, 252, 255])

# Fit a Weibull model to each data set and draw both on the same Axes
model_1 = surv.Weibull.fit(y_a)
model_2 = surv.Weibull.fit(y_s)
ax = plt.gca()
model_1.plot(ax=ax)
model_2.plot(ax=ax)

# Intended labels for the artists of each model
p_a = ['A', 'a_u_CI', 'a_l_CI', 'a_fit']
p_s = ['S', 's_u_CI', 's_l_CI', 's_fit']
p_t = p_a + p_s
ax.legend(labels=p_t[0:5:4])
As well as specifying your labels, you also need to tell ax.legend which "handles" to use.
Normally with matplotlib you could label each line or collection as you plot them, but with surpyval this is not possible.
Instead, we can find the relevant handles from the matplotlib Axes instance.
In this case, assuming you want to label the blue and orange circles, you want the two PathCollection instances, which are created using matplotlib's scatter method under the hood. Fortunately, these are easy to access in the ax.collections attribute. So, you just need to change the legend line to:
ax.legend(handles=ax.collections, labels=p_t[0:5:4])
to get this plot:
For reference, the solid red lines and dashed black lines are stored in ax.lines. For this plot, there are 6 Line2D objects stored there, corresponding to the 4 red lines and 2 black lines on the plot.
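If you want to see where the artists end up without surpyval installed, here is a minimal sketch using plain matplotlib. The plotting loop is only a stand-in I made up to mimic what the two model.plot(ax=ax) calls appear to add (one scatter, two red lines and one dashed black line per model); it is not surpyval's actual code, and the data are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()

# Stand-in (assumption) for model_1.plot(ax=ax) and model_2.plot(ax=ax):
# each pass adds one PathCollection via scatter and three Line2D via plot.
x = np.linspace(180, 220, 20)
base = np.linspace(0.05, 0.95, 20)
for offset in (0.0, 0.1):
    ax.scatter(x, base + offset)
    for shift in (-0.05, 0.05):
        ax.plot(x, base + offset + shift, color="r")
    ax.plot(x, base + offset, color="k", linestyle="--")

# Scatter artists live in ax.collections, line artists in ax.lines
n_collections = len(ax.collections)
n_lines = len(ax.lines)

# Pair the two scatter handles with the two labels explicitly
legend = ax.legend(handles=list(ax.collections), labels=["A", "S"])
legend_labels = [t.get_text() for t in legend.get_texts()]
print(n_collections, n_lines, legend_labels)
```

Passing handles and labels together is what keeps each label attached to the intended artist, rather than to whatever artist happens to come first on the Axes.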
In answer to the comment below ("Do you know how to turn the solid lines into fill_between?"), to fill between two of the lines, you can grab the x- and y-data from the two lines, and then use matplotlib's fill_between function.
For example, to fill between the two red lines for the "A" data, you could do this:
l1 = ax.lines[0]
l2 = ax.lines[1]
x1 = l1.get_xdata()
x2 = l2.get_xdata()
y1 = l1.get_ydata()
y2 = l2.get_ydata()
ax.fill_between(x1, y1, y2, zorder=-5, alpha=0.5, color='r')
Note that you need to set the zorder to something lower than the scatter points, otherwise it'll hide them.
For the other two red lines in this plot (the "S" data), you would want to use ax.lines[3] and ax.lines[4].
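Since the same steps are repeated for each pair of lines, they could be wrapped in a small helper. This is only a sketch: fill_between_lines is a hypothetical name of mine, not part of matplotlib or surpyval, and it assumes both lines share the same x-data. The demo lines below are made-up stand-ins for the red confidence bounds.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def fill_between_lines(ax, lower, upper, **kwargs):
    """Shade the region between two Line2D objects (hypothetical helper).

    Assumes both lines were drawn on the same x-values.
    """
    x = np.asarray(lower.get_xdata())
    if not np.array_equal(x, np.asarray(upper.get_xdata())):
        raise ValueError("lines must share the same x-data")
    return ax.fill_between(x, lower.get_ydata(), upper.get_ydata(), **kwargs)


fig, ax = plt.subplots()
x = np.linspace(0, 1, 50)
ax.plot(x, x - 0.1, color="r")  # stand-in for a lower bound
ax.plot(x, x + 0.1, color="r")  # stand-in for an upper bound
ax.scatter(x, x)                # stand-in for the data points

# Low zorder keeps the shaded band behind the scatter points
band = fill_between_lines(ax, ax.lines[0], ax.lines[1],
                          zorder=-5, alpha=0.5, color="r")
```

On the real plot you would call it with ax.lines[0] and ax.lines[1] for "A", then ax.lines[3] and ax.lines[4] for "S". One thing to watch: the filled band is itself added to ax.collections, so if you build the legend from ax.collections after filling, it is safer to pick out the scatter handles explicitly.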
Note this was all tested with matplotlib v3.6, surpyval v0.10.10 and Python 3.9.
I found that using a later release of matplotlib (v3.7.2, Sept 2023) did not work, due to some issues with the way surpyval tries to plot the grid.
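Because the behaviour seems to depend on the installed versions, it may help to check what you have before trying the above. A minimal sketch using only the standard library (installed_version is a hypothetical helper name):

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Report the versions of the two packages discussed in this answer
for name in ("matplotlib", "surpyval"):
    print(name, installed_version(name))
```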