Search code examples
Tags: python, matplotlib, legend

Custom legend for a plot with lines that change colour


I want to make two error-bar plots whose lines change colour: one going from pink to blue, the other from blue to pink. I could not find a way to do this with plt.errorbar(), but I found a workaround using LineCollection. I draw the error-bar plots with a dotted line and then add two colour-changing lines with LineCollection. Each line is split into segments, and each segment is coloured from the chosen colormap according to its position along the line.

The problem now is creating a legend for this plot. I am using a custom legend, but I cannot find a way to show colour-changing lines in it. Is there a way to do this?

Here is the code that makes the plot:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import BoundaryNorm, ListedColormap
from matplotlib.lines import Line2D

def plot_gradient_errorbar(ax, x, data, cmap='cool', reverse=False, linewidth=4):
    """Draw mean +/- std error bars for *data* plus a colour-graded mean line.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x : 1-D array of x positions, one per row of *data*.
    data : 2-D array; row ``i`` holds the samples at position ``x[i]``.
    cmap : colormap name or object used for the graded line.
    reverse : if True, traverse the colormap from its end back to its start.
    linewidth : width of the graded line.

    Returns
    -------
    The LineCollection that was added to *ax*.
    """
    y_mean = np.mean(data, axis=1)          # mean for every row
    y_std = np.std(data, axis=1, ddof=1)    # sample std for every row
    ax.errorbar(x, y_mean, yerr=y_std, capsize=6, elinewidth=4,
                ecolor='grey', linestyle='dotted', color='black')

    # Build a (numlines) x (points per line = 2) x 2 array: segment i joins
    # point i to point i+1, so every segment can be coloured individually.
    points = np.array([x, y_mean]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # One colour value per segment. Deriving the count from the segments
    # (instead of hard-coding it) keeps the colours aligned if T changes.
    col = np.linspace(0.0, 1.0, len(segments))
    if reverse:
        col = col[::-1]

    lc = LineCollection(segments, cmap=cmap, norm=plt.Normalize(0.0, 1.0))
    lc.set_array(col)
    lc.set_linewidth(linewidth)
    ax.add_collection(lc)
    return lc


fig, axs = plt.subplots(1, 1, figsize=(10, 8))

T = 9
x = np.linspace(0, T - 1, T)

# first plot with error bars; its line goes from blue to pink
data = np.random.rand(T, 30)
plot_gradient_errorbar(axs, x, data, cmap='cool')

# second plot with error bars (shifted up by 1); its line goes from pink to blue
data = np.random.rand(T, 30) + 1
plot_gradient_errorbar(axs, x, data, cmap='cool', reverse=True)

fig.suptitle("example plot")

legend_elements = [
    Line2D([0], [0], color='black', linestyle='dotted', alpha=1, label=r'line1', lw=4),
    Line2D([0], [0], color='black', linestyle='dotted', alpha=1, label=r'line2', lw=4),
]
axs.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc="upper left", title_fontsize=18)

Here is the resulting plot: [image: example plot]

I tried building a custom legend from "colour patches" created with matplotlib.lines.Line2D, but each handle seems to accept only a single colour.


Solution

  • This can be done by slightly adapting the answer I posted to a similar question about custom legend handles for multi-colour points (https://stackoverflow.com/a/67870930/9703451).

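    The idea is to register a custom legend handler that draws the legend patch itself, as a row of narrow rectangles coloured from a colormap. Besides `plt` and `fig` from the code above, the snippet needs `PatchCollection` from `matplotlib.collections`. This should do the job: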

    # define an object that will be used by the legend
    class MulticolorPatch:
        """Legend handle describing a colour gradient taken from a colormap.

        It draws nothing itself; the legend handler reads ``cmap`` and
        ``ncolors`` from it when the legend is rendered.

        Parameters
        ----------
        cmap : str or callable
            A colormap name (resolved with ``plt.get_cmap``) or a colormap
            object mapping a float in [0, 1] to a colour.
        ncolors : int, optional
            Number of colour stripes the handler should draw (default 100).

        Raises
        ------
        ValueError
            If ``ncolors`` is smaller than 1; the handler divides the swatch
            width by it, so such values would otherwise fail only at draw time.
        """

        def __init__(self, cmap, ncolors=100):
            if ncolors < 1:
                raise ValueError(f"ncolors must be at least 1, got {ncolors!r}")
            self.ncolors = ncolors
            # accept either a registered colormap name or a colormap object
            self.cmap = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
            
    # define a handler for the MulticolorPatch object
    class MulticolorPatchHandler:
        """Legend handler that draws a MulticolorPatch as a colour gradient.

        Register it with ``handler_map={MulticolorPatch: MulticolorPatchHandler()}``.
        """

        def legend_artist(self, legend, orig_handle, fontsize, handlebox):
            """Fill *handlebox* with ``orig_handle.ncolors`` adjacent stripes.

            Each stripe is a borderless rectangle coloured from
            ``orig_handle.cmap``. Together they form a left-to-right gradient
            covering the colormap from 0.0 to 1.0 inclusive.

            Returns the PatchCollection added to *handlebox*.
            """
            # Imported locally so the handler works without an extra
            # top-level import.
            from matplotlib.collections import PatchCollection

            n = orig_handle.ncolors
            width, height = handlebox.width, handlebox.height
            stripe_width = width / n
            # Offset by the descents so the stripes line up with the box origin.
            x0, y0 = -handlebox.xdescent, -handlebox.ydescent

            # Sample both ends of the colormap so the first and last stripes
            # show its first and last colours, matching a line coloured over
            # the full [0, 1] range.
            values = np.linspace(0.0, 1.0, n)
            patches = [
                plt.Rectangle([x0 + stripe_width * i, y0],
                              stripe_width,
                              height,
                              facecolor=orig_handle.cmap(v),
                              edgecolor='none')
                for i, v in enumerate(values)
            ]

            patch = PatchCollection(patches, match_original=True)

            handlebox.add_artist(patch)
            return patch
    
    # ------ pair each colormap with the label shown next to its swatch
    entries = [
        ("cool", "a cool line"),
        ("cool_r", "a reversed cool line"),
        ("viridis", "a viridis line"),
    ]
    handles = [MulticolorPatch(cmap_name) for cmap_name, _ in entries]
    labels = [text for _, text in entries]

    # ------ create the legend; handler_map routes every MulticolorPatch
    # handle to the custom handler that draws the gradient swatch
    fig.legend(
        handles,
        labels,
        loc='upper left',
        handler_map={MulticolorPatch: MulticolorPatchHandler()},
        bbox_to_anchor=(.125, .875),
    )
    
    

    [image: output of the code above]