
Why does adding a legend to this cartopy projection figure significantly increase the execution time (and how can I fix this)?


So I have written a script for making figures that works alright (see the mock-up example below). But when I add a legend to the figure, the execution time increases a lot. I don't really understand what is happening here; I would naively expect adding a legend to be a simple operation.

I suspect this has something to do with the cartopy projection, since it works alright if I don't use it.

What is the problem here, and how can I avoid this?

Problematic code:

import numpy as np
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

# Mockup dataset
num = 300
lat = np.linspace(-54,-59,num=num)
lon = np.linspace(-5,5, num=num)

data = np.outer(lat,lon)

ds = xr.DataArray(data=data,
                 dims=["lat", "lon"],
                 coords=dict(lon=lon, lat=lat))


# Map projection
map_proj = ccrs.SouthPolarStereo()

ax = plt.axes(projection=map_proj)
ax.gridlines(draw_labels=True)
ax.set_extent([-3,4,-58,-54])

# Plot image
ds.plot(cmap="gray", 
        add_colorbar=False,
        transform=ccrs.PlateCarree(), # data projection
        subplot_kws={'projection': map_proj}) # map projection

# Plot contours
cs = ds.plot.contour(transform=ccrs.PlateCarree())

# Make legend
proxy = [matplotlib.lines.Line2D([],[], c=pc.get_color()[0]) for pc in cs.collections]
labels = ["foo"] * len(cs.collections)
plt.legend(proxy, labels)

Code without cartopy projection:

import numpy as np
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt

# Mockup dataset
num = 300
lat = np.linspace(-54,-59,num=num)
lon = np.linspace(-5,5, num=num)

data = np.outer(lat,lon)

ds = xr.DataArray(data=data,
                 dims=["lat", "lon"],
                 coords=dict(lon=lon, lat=lat))

# Plot image
ds.plot(cmap="gray",
        add_colorbar=False)  # no map projection here

# Plot contours
cs = ds.plot.contour()

# Make legend
proxy = [matplotlib.lines.Line2D([],[], c=pc.get_color()[0]) for pc in cs.collections]
labels = ["foo"] * len(cs.collections)  # was missing: labels must be defined in this script too
plt.legend(proxy, labels)

Solution

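  • plt.legend(proxy, labels) defaults to loc='best', which places the legend at whichever of the nine candidate locations overlaps least with the artists already in the axes. Finding that spot means testing each candidate against the plotted data, which can be slow if you have lots of data in your axes (here a 300×300 image plus its contour lines), and particularly slow if that data also goes through complicated, non-affine transforms such as a cartopy map projection, because the data may then have to be re-projected just for this overlap test. Without cartopy the transforms are cheap affine ones, which is why the plain version stays fast. Instead, set the location manually, e.g. ax.legend(proxy, labels, loc='upper right'); any explicit location skips the search. See https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.legend.html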
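To see the effect of loc='best' versus a fixed location, here is a minimal sketch that assumes only matplotlib and NumPy (no cartopy, so the gap will be smaller than in the map figure above). The helper name time_legend_draw and the synthetic sine data are hypothetical, chosen for illustration; the function times one full draw of a figure whose legend uses a given loc and reports where the legend ended up.

```python
import time

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def time_legend_draw(loc, n_points=200_000):
    """Hypothetical helper: draw a figure whose legend uses `loc`.

    Returns (seconds spent in one full draw, legend bbox in axes-fraction
    coordinates). The timing covers the whole draw, not only the legend.
    """
    fig, ax = plt.subplots()
    x = np.linspace(0, 10, n_points)
    # A single long line; with loc="best" the legend code checks its
    # candidate positions against this line's vertices to minimise overlap.
    ax.plot(x, np.sin(x), label="data")
    leg = ax.legend(loc=loc)

    start = time.perf_counter()
    fig.canvas.draw()  # a "best" location is only resolved at draw time
    elapsed = time.perf_counter() - start

    # Convert the legend's pixel bounding box to axes fractions (0..1).
    renderer = fig.canvas.get_renderer()
    bbox = leg.get_window_extent(renderer).transformed(ax.transAxes.inverted())
    plt.close(fig)
    return elapsed, bbox


if __name__ == "__main__":
    for loc in ("best", "upper right"):
        seconds, bbox = time_legend_draw(loc)
        print(f"loc={loc!r}: {seconds:.3f} s, legend top-right corner at "
              f"({bbox.x1:.2f}, {bbox.y1:.2f}) in axes coordinates")
```

On a plain axes like this the difference may be modest; it grows with the amount of data and with how expensive the axes transforms are, which is why the cartopy figure is hit so hard. Exact timings depend on the machine and the matplotlib version.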