Create a rectangular patch with upper and lower edge in matplotlib


I have the following plot in matplotlib:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

x=np.linspace(0,2.,100)
y1=np.power(x,2)*2.+x
y2=np.power(x,2)*2-0.2

plt.plot(x,y1,color="k")
plt.plot(x,y2,color="k")

plt.fill_between(x,y1,y2,facecolor=(0,0,0,0.3),lw=0)

and I want to add a legend similar to this:

legend_elements = [Patch(facecolor=(0,0,0,0.4), edgecolor='k',
                         label='Filled area')]
plt.gca().legend(handles=legend_elements)

This produces the following plot:

[Figure: plot with filled area]

However, I would like to remove the lateral edges from the patch in the legend, to match what I see in the plot (i.e., I only have the upper and lower edge).

I searched how to draw only some edges of a matplotlib patch but did not find anything. Is something like this possible at all?


Solution

  • You can write your own legend handler by extending the built-in HandlerPolyCollection, which is the handler used for the PolyCollection that fill_between returns:

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    
    x = np.linspace(0, 2., 100)
    y1 = np.power(x, 2) * 2 + x
    y2 = np.power(x, 2) * 2 - 0.2
    
    plt.plot(x, y1, color="k")
    plt.plot(x, y2, color="k")
    
    plt.fill_between(x, y1, y2, fc=(0, 0, 0, 0.3), lw=0, label="Filled area")
    
    class HandlerFilledBetween(mpl.legend_handler.HandlerPolyCollection):
        def create_artists(self, legend, orig_handle, xdescent, ydescent,
                           width, height, fontsize, trans):
            # The parent handler returns a single Rectangle styled like the fill.
            p = super().create_artists(legend, orig_handle, xdescent, ydescent,
                                       width, height, fontsize, trans)[0]
            # Corners of that rectangle in the legend handle's coordinates.
            x0, y0 = p.get_x(), p.get_y()
            x1 = x0 + p.get_width()
            y1 = y0 + p.get_height()
            # Draw only the upper and lower edges as separate lines.
            line_upper = mpl.lines.Line2D([x0, x1], [y1, y1], color='k')
            line_lower = mpl.lines.Line2D([x0, x1], [y0, y0], color='k')
            return [p, line_upper, line_lower]
        
    plt.gca().legend(handler_map={mpl.collections.PolyCollection: HandlerFilledBetween()})
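
If the axes contain more than one filled region, mapping the handler to the PolyCollection class changes every such legend entry. A possible variation, sketched below, keys handler_map on the specific artist returned by fill_between instead, so only that entry gets the upper and lower edges. The class name HandlerBandEdges, its edge_color and edge_width parameters, the second "Other area" band, and the Agg backend line are assumptions added for illustration; they are not part of the original answer.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerPolyCollection
from matplotlib.lines import Line2D


class HandlerBandEdges(HandlerPolyCollection):
    """Hypothetical handler: filled rectangle with top and bottom edges only."""

    def __init__(self, edge_color="k", edge_width=1.5, **kwargs):
        super().__init__(**kwargs)
        self.edge_color = edge_color  # illustrative parameter
        self.edge_width = edge_width  # illustrative parameter

    def create_artists(self, legend, orig_handle, xdescent, ydescent,
                       width, height, fontsize, trans):
        # The parent handler returns one Rectangle styled like the fill.
        patch = super().create_artists(legend, orig_handle, xdescent, ydescent,
                                       width, height, fontsize, trans)[0]
        x0, y0 = patch.get_x(), patch.get_y()
        x1, y1 = x0 + patch.get_width(), y0 + patch.get_height()
        edges = []
        for y in (y1, y0):  # upper edge first, then lower edge
            line = Line2D([x0, x1], [y, y],
                          color=self.edge_color, linewidth=self.edge_width)
            line.set_transform(trans)
            edges.append(line)
        return [patch, *edges]


x = np.linspace(0, 2.0, 100)
y_upper = 2 * x**2 + x
y_lower = 2 * x**2 - 0.2

fig, ax = plt.subplots()
ax.plot(x, y_upper, color="k")
ax.plot(x, y_lower, color="k")
band = ax.fill_between(x, y_upper, y_lower, fc=(0, 0, 0, 0.3), lw=0,
                       label="Filled area")
# A second band that should keep the default legend entry.
other = ax.fill_between(x, y_lower - 1.0, y_lower - 0.5, fc=(0, 0, 1, 0.3),
                        lw=0, label="Other area")

handler = HandlerBandEdges(edge_color="k")
# Key the map on the artist instance, not on the PolyCollection class.
legend = ax.legend(handler_map={band: handler})
fig.canvas.draw()  # render once; use plt.show() or fig.savefig(...) as needed
```

With this mapping, the "Filled area" entry should be drawn with only its upper and lower edges, while "Other area" falls back to matplotlib's default handler.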
    
