Choosing artists/handles for a legend when plotting from an xarray.DataArray


I am creating a facet plot from an xarray DataArray, with multiple lines in each subplot. I want each subplot to have its own legend, but the legend should contain only certain lines. ax.legend() expects the artists (handles) that should appear in the legend, but because I am plotting directly from an xarray DataArray, I don't know how to get hold of them.

Here is an example:

import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
import pandas as pd
           
data1 = np.random.randn(4, 4,3)
loc= np.linspace(1,3,3)
type = ["a", "b", "c", "d"]
            
times = pd.date_range("2000-01-01", periods=4)
foo = xr.DataArray(data1, coords=[times, type, loc], dims=["time", "type","loc"])
    
t=foo.plot.line(x="time", col="loc", linewidth= 5, col_wrap=3)
for i, ax in enumerate(t.axes.flat):
    ax.legend(('labelc','labeld'))

[Image: resulting plot]

Here, I want the legend entries to correspond to the lines for c and d. As written, the two labels are attached to the first two lines drawn instead.

ax.legend is supposed to work like this:

ax.legend([line1, line2, line3], ['label1', 'label2', 'label3'])

I have already tried the following, but could not get it to work:

for i, ax in enumerate(t.axes.flat):
    ax.legend(foo[:,2:,i],('labelc','labeld'))

Solution

  • We can retrieve the figure-level legend entries and filter them for the desired entries:

    import matplotlib.pyplot as plt
    import xarray as xr
    import numpy as np
    import pandas as pd
               
    data1 = np.random.randn(4, 4, 3)
    locs = np.linspace(1, 3, 3)
    types = ["a", "b", "c", "d"]
                
    times = pd.date_range("2000-01-01", periods=4)
    foo = xr.DataArray(data1, coords=[times, types, locs], dims=["time", "type","loc"])
        
    t = foo.plot.line(x="time", col="loc", linewidth=5, col_wrap=3)

    # list of entries to show in the subplot legends
    to_keep = ["b", "d"]
    # retrieve the figure-level labels and handles
    # (note: _hue_var and _mappables are private FacetGrid attributes,
    # not part of xarray's public API)
    labels = t._hue_var.to_numpy()
    handles = t._mappables[-1]

    # map labels to handles while dropping unwanted entries
    label_dic = {l: h for l, h in zip(labels, handles) if l in to_keep}

    # create a legend with the filtered entries on each subplot
    for ax in t.axes.flat:
        ax.legend(label_dic.values(), label_dic.keys(), loc="best")
    plt.show()
    

    Sample output: [image]

    The code that retrieves the legend entries is adapted from the source code of xarray's FacetGrid.add_legend().

    Please also note that I renamed your variable type to types (and loc to locs), because type shadowed Python's built-in type().
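
As a possible alternative that avoids the private attributes, here is a minimal sketch that takes the handles from each subplot with matplotlib's public ax.get_lines(). It assumes that the lines in every subplot are drawn in the same order as the "type" coordinate, which is worth checking against your xarray version. The names grid, axes and to_keep are only illustrative, and the axs lookup is there because newer xarray releases expose the subplot array under that name.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

rng = np.random.default_rng(0)
types = ["a", "b", "c", "d"]
foo = xr.DataArray(
    rng.standard_normal((4, 4, 3)),
    coords=[pd.date_range("2000-01-01", periods=4), types, np.linspace(1, 3, 3)],
    dims=["time", "type", "loc"],
)

grid = foo.plot.line(x="time", col="loc", col_wrap=3)

# Newer xarray versions name the subplot array `axs`; older ones use `axes`.
axes = grid.axs if hasattr(grid, "axs") else grid.axes

# Entries to show in each subplot legend (the ones asked for in the question).
to_keep = ["c", "d"]
# Assumption: one line per "type" value, drawn in coordinate order.
all_labels = foo["type"].values.tolist()

for ax in axes.flat:
    # Pair every line of this subplot with its label and keep the wanted ones.
    pairs = [
        (line, label)
        for line, label in zip(ax.get_lines(), all_labels)
        if label in to_keep
    ]
    if pairs:  # skip empty panels
        handles, labels = zip(*pairs)
        ax.legend(handles, labels, loc="best")

# Call plt.show() here when running interactively.
```

Because each legend uses the lines of its own subplot, this also works if the line styles differ between subplots.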