Tags: python, matplotlib, legend, legend-properties

How can I create a custom arrow-shaped legend key?


I'm creating a plot using Matplotlib. I want the legend keys (the shapes next to the descriptive text in the legend) to be arrows similar to the ones in the figure itself. I've been able to get output that is close to what I want, as you can see in the picture below: there are arrows in the legend, but they are incredibly long and narrow.

[Image: The current iteration of my plot. The legend-key arrows are present but tiny.]

The code I used to create the plot is here:

```python
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerBase
from matplotlib.patches import FancyArrowPatch

# creek_long_profile_dataframes and creek_sed_delivery_dataframes hold one
# pandas DataFrame per creek, keyed by creek name (loaded earlier, not shown).

# Define custom legend handler
class ArrowHandler(HandlerBase):
    def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
        arrow = FancyArrowPatch((0, 5), (40, 5), arrowstyle='->', color=orig_handle.get_edgecolor(), mutation_aspect=20)
        return [arrow]

# Plot the river profile using Matplotlib
plt.figure(figsize=(11, 6))
plt.plot(creek_long_profile_dataframes['Foley Creek']['distance_upstream_km'],
         creek_long_profile_dataframes['Foley Creek']['elevation_m'],
         color='#368bb5', linewidth=2)

### Add sediment delivery points as colored arrows
delivery_data = creek_sed_delivery_dataframes["Foley Creek"]

# Get range of x and y values to be able to standardize the size of the arrows
# (index the agg() result by label; positional access like x_min_max[1] is deprecated in pandas)
x_min_max = creek_long_profile_dataframes['Foley Creek']['distance_upstream_km'].agg(['min', 'max'])
x_range = x_min_max['max'] - x_min_max['min']
y_min_max = creek_long_profile_dataframes['Foley Creek']['elevation_m'].agg(['min', 'max'])
y_range = y_min_max['max'] - y_min_max['min']

# Add arrows to the figure
for _, row in delivery_data.iterrows():
    x = row['distance_upstream_km']
    y = row['elevation_m']
    color = '#1b9e77' if row['mass_movement_type'] == 'Debris Flow' else '#7570b3'
    plt.arrow(x, y+(0.098*y_range), 0, -(0.088*y_range), width=0.006*x_range, head_width=0.015*x_range,
              head_length=0.029*y_range, color=color, length_includes_head=True)

### done adding sediment delivery arrows

# Set axis labels
plt.xlabel('Distance upstream (km)')
plt.ylabel('Elevation (m)')

# Set title
plt.title('Sediment Delivery Along Foley Creek')

# Add grid
plt.grid(which='major', axis='both')

# Define the legend elements
legend_elements = [
    FancyArrowPatch((0, 0), (0, 0), arrowstyle='->', color='#1b9e77', label='Debris Flow'),
    FancyArrowPatch((0, 0), (0, 0), arrowstyle='->', color='#7570b3', label='Debris Avalanche')
]

# Create the legend with custom handler
plt.legend(handles=legend_elements, handler_map={FancyArrowPatch: ArrowHandler()})

# Save figure
# plt.savefig('sed_delivery_figs/sed_delivery_foley.pdf', dpi=300, bbox_inches='tight')

# Show the plot
plt.show()
```

I first tried passing a list of FancyArrowPatch instances directly to the legend, similar to the code below.

```python
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

# Define the legend elements
legend_elements = [
    FancyArrowPatch((0, 0), (0.5, 0), arrowstyle='->', color='#1b9e77', label='Debris Flow'),
    FancyArrowPatch((0, 0), (0.5, 0), arrowstyle='->', color='#7570b3', label='Debris Avalanche')
]

# Add the legend
plt.legend(handles=legend_elements, loc='center')
```

This produced legend keys in the correct colors, but they were rectangles rather than the arrows I was hoping for.
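The rectangles are Matplotlib's fallback behavior. When a handle has no entry of its own in the legend's handler map, the legend walks the handle's class hierarchy (its method resolution order) and uses the first class it finds. FancyArrowPatch has no default entry, so it resolves to the entry for its base class Patch, which is HandlerPatch, and HandlerPatch draws a filled rectangle in the handle's colors. The short check below, assuming a recent Matplotlib 3.x, asks the legend machinery which handler it would pick for one of the proxies above using `Legend.get_legend_handler` and `Legend.get_default_handler_map`:

```python
from matplotlib.legend import Legend
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import FancyArrowPatch

# A proxy handle like the ones in the first attempt.
proxy = FancyArrowPatch((0, 0), (0.5, 0), arrowstyle="->",
                        color="#1b9e77", label="Debris Flow")

# Resolve the handler the way Legend does: try the handle itself first,
# then each class in type(proxy).mro() (FancyArrowPatch -> Patch -> Artist).
handler = Legend.get_legend_handler(Legend.get_default_handler_map(), proxy)

print(type(handler).__name__)             # HandlerPatch
print(isinstance(handler, HandlerPatch))  # True
```

That is why a custom `handler_map` entry (or a different handle type) is needed before any arrow shape can appear.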

I cobbled together some code to create a custom legend handler, but unfortunately I have very little understanding of how handlers work or how I should define one. Any help would be appreciated.
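One piece of the puzzle: a handler's `create_artists` only ever draws inside a small "handle box" whose geometry the legend passes in as `xdescent`, `ydescent`, `width` and `height`, measured in points. With default rcParams that box is `legend.handlelength` (2.0) times the legend font size wide and `legend.handleheight` (0.7) times it tall, so roughly 20 × 7 points at the default 10-point font. The hard-coded `(0, 5)` to `(40, 5)` arrow in the handler above is therefore about twice as wide as the box, and `mutation_aspect=20` further distorts the head's proportions. A minimal way to see the numbers, using a hypothetical `RecordingHandler` that subclasses `HandlerPatch` and just logs what it receives:

```python
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import Patch


class RecordingHandler(HandlerPatch):
    """Hypothetical helper: draw the usual rectangle, but remember the
    handle-box geometry that the legend passes to create_artists."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seen = []

    def create_artists(self, legend, orig_handle, xdescent, ydescent,
                       width, height, fontsize, trans):
        # All geometry values are in points, in the handle box's own frame.
        self.seen.append((xdescent, ydescent, width, height, fontsize))
        return super().create_artists(legend, orig_handle, xdescent, ydescent,
                                      width, height, fontsize, trans)


fig, ax = plt.subplots()
recorder = RecordingHandler()
# The legend calls create_artists once per handle while it is being built.
legend = ax.legend(handles=[Patch(color="#1b9e77", label="Debris Flow")],
                   handler_map={Patch: recorder})

xdescent, ydescent, width, height, fontsize = recorder.seen[0]
print(f"handle box: {width} x {height} pt at fontsize {fontsize}")
```

Any shape a handler returns should be sized from these arguments rather than from fixed numbers; passing `handlelength=` or `handleheight=` to `legend()` enlarges the box if the key needs more room.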

I've looked at some related posts (here and here), but I haven't been able to apply their suggestions to my use case because I'm having trouble wrapping my head around some aspects of Matplotlib.

Does anyone know how I can put arrows in the legend like the ones that I plotted in the figure itself? Is there a simpler approach?


Solution

  • Good job laying out your question clearly, with relevant supporting details.

    I can't say I understand how all of this works, but I can say that using FancyArrow (from matplotlib.patches) instead of FancyArrowPatch worked for me. Can you try something like this?

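    ```python
    import matplotlib.patches as mpatches
    from matplotlib.legend_handler import HandlerBase

    # Define the custom legend handler once; it does not need to live inside a loop over creeks
    class ArrowHandler(HandlerBase):
        def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
            # FancyArrow is a plain polygon, drawn in the legend handle box's units (points)
            arrow = mpatches.FancyArrow(0, 3, 25, 0, color=orig_handle.get_edgecolor(), width=2.5, length_includes_head=False)
            return [arrow]

    # Use it exactly as in the question:
    # plt.legend(handles=legend_elements, handler_map={FancyArrowPatch: ArrowHandler()})
    ```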

    Hopefully you'll get nice arrows with clearly discernible arrow-shaped heads in your legend. (I got that idea from https://matplotlib.org/stable/gallery/shapes_and_collections/arrow_guide.html)

    [Image: legend with nice arrows]
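A variant of the same idea that avoids hard-coded sizes: instead of subclassing, `HandlerPatch` accepts a `patch_func` callback that builds the legend patch from the handle box's `xdescent`, `ydescent`, `width` and `height`, after which the handler copies color and line properties from the legend handle onto it. The sketch below uses that hook with `FancyArrow` proxies. The creek data is replaced by made-up numbers, and `make_legend_arrow` and its sizing fractions are illustrative choices, not taken from either post:

```python
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import FancyArrow


def make_legend_arrow(legend, orig_handle, xdescent, ydescent,
                      width, height, fontsize):
    """Hypothetical patch_func: a right-pointing arrow that fills the handle box.

    The box spans x in [-xdescent, width - xdescent] and
    y in [-ydescent, height - ydescent], measured in points.
    """
    return FancyArrow(
        -xdescent, 0.5 * height - ydescent,  # tail: left edge, mid-height
        width, 0,                            # run across the full box width
        width=0.25 * height,                 # shaft thickness
        head_width=0.9 * height,             # head almost as tall as the box
        head_length=0.35 * width,
        length_includes_head=True,           # tip stops at the right edge
    )


# Made-up stand-in for the Foley Creek profile and delivery points.
distance_km = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
elevation_m = [100, 140, 190, 260, 340, 430]
deliveries = [(1.5, 165, "Debris Flow"), (3.5, 300, "Debris Avalanche")]
colors = {"Debris Flow": "#1b9e77", "Debris Avalanche": "#7570b3"}

fig, ax = plt.subplots(figsize=(11, 6))
ax.plot(distance_km, elevation_m, color="#368bb5", linewidth=2)
for x, y, kind in deliveries:
    # Downward arrows in data coordinates, as in the question.
    ax.arrow(x, y + 40, 0, -35, width=0.03, head_width=0.08,
             head_length=12, color=colors[kind], length_includes_head=True)

# Proxy handles: only their color and style are used; patch_func supplies the shape.
legend_elements = [FancyArrow(0, 0, 1, 0, color=c, label=k)
                   for k, c in colors.items()]
legend = ax.legend(
    handles=legend_elements,
    handler_map={FancyArrow: HandlerPatch(patch_func=make_legend_arrow)},
)
ax.set_xlabel("Distance upstream (km)")
ax.set_ylabel("Elevation (m)")
```

Call `plt.show()` or `fig.savefig(...)` as before. Because `plt.arrow`/`ax.arrow` also return `FancyArrow` objects, the arrows drawn in the loop could be passed as handles directly instead of building proxies. A vertical key is possible by swapping the roles of `width` and `height` in `make_legend_arrow`, but with default settings it only has about 7 points of height to work with, so increasing `handleheight` in `legend()` would likely be needed.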