Search code examples
Tags: python, matplotlib, legend, geopandas, cartopy

How to create a single figure legend for GeoAxes subplots


I have looked at many other questions here and tried their solutions, but for whatever reason I cannot get this to work. Each solution either gives me the same error or produces nothing at all.

I have a list of six dataframes that I loop through to create a figure of six maps. Each dataframe is formatted similarly, the only difference being its temporal column. Each map uses the same classification scheme, created with mapclassify through the geopandas `scheme` argument (cartopy supplies only the map projection). The colors are determined by a colormap; the dataframe itself contains no color information. I want a single legend for all the maps, so that it is more visible to readers and less redundant. Here is my code:

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import mapclassify
import matplotlib.pyplot as plt
from matplotlib.colors import rgb2hex
from matplotlib.colors import ListedColormap
plt.style.use('seaborn-v0_8-dark')

# Define the Robinson projection
robinson = ccrs.Robinson()

# Upper edges of the user-defined classes. Used both for plotting and for
# rebuilding the legend labels, so the two can never drift apart.
CLASS_BINS = [20, 50, 150, 300, 510]

# Create a 3x2 grid of subplots
fig, axs = plt.subplots(3, 2, figsize=(12, 12), subplot_kw={'projection': robinson})

# Flatten the subplot array for easy iteration
axs = axs.flatten()


# Define color map and how many colors are needed (one per class).
# matplotlib.cm.get_cmap (reached as plt.cm.get_cmap) was deprecated in
# Matplotlib 3.7 and removed in 3.9; pyplot.get_cmap(name, lut) returns the
# same colormap resampled to `lut` colors.
cmap = plt.get_cmap('YlOrRd', len(CLASS_BINS))  # Blues #Greens #PuRd #YlOrRd
# Any countries with NaN values will be colored grey
missing_kwds = dict(color='grey', label='No Data')


# Loop through the dataframes and create one submap per dataframe
for i, df in enumerate(dataframes):

    # Reproject the country geometries into the Robinson CRS used by the axes
    mentionsgdf_robinson = df.to_crs(robinson.proj4_init)

    # Plot the submap
    ax = axs[i]

    # Add land mask and gridlines
    ax.add_feature(cfeature.LAND.with_scale('50m'), facecolor='lightgrey')
    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                      linewidth=1, color='gray', alpha=0.3, linestyle='--')

    gl.xlabel_style = {'fontsize': 7}
    gl.ylabel_style = {'fontsize': 7}

    # Classification scheme options: EqualInterval, Quantiles, NaturalBreaks, UserDefined etc.
    mentionsgdf_robinson.plot(column='mentions',
                              ax=ax,
                              legend=True,
                              cmap=cmap,
                              legend_kwds={'loc': 'center left',
                                           'title': 'Number of Mentions',
                                           'prop': {'size': 7, 'family': 'serif'}},
                              missing_kwds=missing_kwds,
                              scheme="UserDefined",
                              classification_kwds={'bins': CLASS_BINS})

    # Set the title of each submap (i = 0 -> 2015)
    ax.set_title(f'20{i+15}', size=15, family='Serif')

    # Recompute the class edges for the legend labels. geopandas classifies only
    # the non-missing values (NaN rows are drawn via missing_kwds), so drop NaN
    # here as well to obtain the same edges geopandas used.
    mentions = mentionsgdf_robinson.mentions.dropna()
    upper_bounds = mapclassify.UserDefined(mentions, bins=CLASS_BINS).bins
    # Each class runs from the previous edge up to its own edge; the first
    # class starts at the smallest observed value.
    lower_bounds = [mentions.min(), *upper_bounds[:-1]]
    bounds = [f'{lower:.0f} - {upper:.0f}'
              for lower, upper in zip(lower_bounds, upper_bounds)]

    legend = ax.get_legend()

    # Restyle the legend title
    legend_title = legend.get_title()
    legend_title.set_fontsize(8)
    legend_title.set_family('serif')

    # Replace the default class labels with "low - high" ranges. zip stops at
    # the shorter sequence, so any extra entries (e.g. 'No Data') keep their text.
    for bound, legend_label in zip(bounds, legend.get_texts()):
        legend_label.set_text(bound)


fig.suptitle(' Yearly Country Mentions in Online News about Species Threatened by Trade ', fontsize=15, family='Serif')

# Adjust spacing between subplots
plt.tight_layout(pad=4.0)

# Save the figure
#plt.savefig('figures/submaps_5years.png', dpi=300)

# Show the submap
plt.show()

And here is my current result. I would like to have just one legend, placed to the side of the maps and vertically centered.

[Image: current output, with a separate legend on every subplot]

I have tried the following code, as suggested in another answer, but only received `UserWarning: Legend does not support handles for PatchCollection instances`. I also didn't know how to incorporate all the other legend modifications I need (bounds, font, bins, etc.) outside of the loop.

# Attempted figure-level legend (did not work for these geopandas maps): it only
# emitted "UserWarning: Legend does not support handles for PatchCollection instances".
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles=handles, labels=labels, loc='upper center')

Here is the data for three years (2015–2017): https://jmp.sh/s/ohkSJpaMZ4c1GifIX0nu

Here are all the files for the global shapefile I used: https://jmp.sh/uTP9DZsC

Using this data with the following code should let you run the full visualization code shared above. Thank you.

import geopandas as gpd
import pandas as pd

# Load the country-borders shapefile and the semicolon-delimited mentions table.
world = gpd.read_file("TM_WORLD_BORDERS-0.3.shp")
df = pd.read_csv("fifsixseventeen.csv", sep=";")

# Split the mentions table into one frame per year, each re-indexed from 0.
fifteen, sixteen, seventeen = (
    df.loc[df['date'] == year].reset_index(drop=True)
    for year in (2015, 2016, 2017)
)

# Function to merge isocodes of the countries with world shapefile 
def merge_isocodes(df, world_gdf=None):
    """Sum mentions per ISO3 code and attach them to the country borders.

    Parameters
    ----------
    df : pandas.DataFrame
        Mentions for one period; must contain ``iso3`` and ``mentions`` columns.
    world_gdf : geopandas.GeoDataFrame, optional
        Country borders with an ``ISO3`` column and a ``geometry`` column.
        Defaults to the module-level ``world`` read from the shapefile.

    Returns
    -------
    geopandas.GeoDataFrame
        One row per row of ``world_gdf`` (right join), in ``world_gdf``'s row
        order. ``iso3`` and ``mentions`` are NaN for countries with no mentions.
        The CRS is taken from ``world_gdf``.
    """
    if world_gdf is None:
        world_gdf = world

    # One row per country code. No sort is needed: a right join keeps the key
    # order of the right frame, and iso3 is unique after the groupby.
    allmentions = df.groupby("iso3")["mentions"].sum().reset_index()

    # Join on column names rather than Series objects; array-like keys make
    # pandas add a synthetic ``key_0`` column that would then need dropping.
    mentionsgdf = pd.merge(allmentions, world_gdf,
                           left_on="iso3", right_on="ISO3", how="right")

    # A plain DataFrame on the left yields a plain DataFrame, so rebuild the
    # GeoDataFrame and state the CRS explicitly so later to_crs() calls have a
    # known source CRS.
    return gpd.GeoDataFrame(mentionsgdf, geometry="geometry", crs=world_gdf.crs)

# Attach each year's mention totals to the country borders (2015, 2016, 2017).
onefive, onesix, oneseven = map(merge_isocodes, (fifteen, sixteen, seventeen))

# Chronological list consumed by the plotting loop.
dataframes = [onefive, onesix, oneseven]

Solution

  • ... 
    
    for i, df in enumerate(dataframes):
        ...

    # after the for loop, use the following code

    # Every subplot uses the same colormap and class bins, so the patches of any
    # one axes legend can stand in for all of them. Take the legend from the last
    # axes that was actually drawn (instead of a hard-coded index), so this works
    # for the 3-year sample data and the full 6-year set alike.
    # NOTE: the first label's lower bound is that year's own minimum value.
    source_legend = axs[len(dataframes) - 1].get_legend()
    # extract the handles (Legend.legend_handles requires matplotlib >= 3.7)
    handles = source_legend.legend_handles
    # get the label text
    labels = [text.get_text() for text in source_legend.texts]
    # get the title text
    title = source_legend.get_title().get_text()
    # Create one figure-level legend. bbox_to_anchor=(1, 0.5) with
    # loc='center left' puts it just outside the right edge of the figure;
    # use savefig(..., bbox_inches='tight') to keep it in a saved image.
    fig.legend(title=title, handles=handles, labels=labels, bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)

    # Remove the now-redundant per-axes legends; axes that were never plotted
    # on have no legend, and get_legend() returns None for them.
    for ax in axs:
        if (axes_legend := ax.get_legend()) is not None:
            axes_legend.remove()
    

    [Image: resulting figure with a single legend to the right of the maps]