Tags: python, matplotlib, legend, spatial, geopandas

How can I add a legend while plotting multiple geopandas dataframes in the same subplot?


I have a GeoDataFrame, world, which I created using:

import geopandas as gpd
import matplotlib.pyplot as plt

world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))

I then created two separate GeoDataFrames, usa and china:

usa = world[world.name == "United States of America"]

china = world[world.name == "China"]

I want to plot the USA in blue and China in red on the map. I plotted it with the following code:

fig, ax = plt.subplots(figsize=(20, 8))
world.plot(ax=ax, color="whitesmoke", ec="black")
usa.plot(ax=ax, color="blue", label="USA")
china.plot(ax=ax, color="red", label="China")
ax.legend()
plt.show()

It looks as follows: [image: world map with the USA in blue and China in red]

I want to add a legend showing blue for the USA and red for China, so I passed label arguments in the code above. However, I get the following warning:

No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.

No legend appears. How can I add legend entries for the USA and China to this plot? Is this possible with geopandas and matplotlib?
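A quick way to see why labelled polygon layers can be left out of the legend is to ask matplotlib which legend handler it would use for them. The sketch below is a minimal check under two assumptions: that geopandas draws each polygon layer as a matplotlib.collections.PatchCollection, and that a hand-made rectangle (labelled "USA" purely for illustration) is a fair stand-in for a real country geometry. It uses Legend.get_default_handler_map() and Legend.get_legend_handler(), the lookup the legend performs when it is called without explicit handles.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the check runs without a display
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.legend import Legend
from matplotlib.patches import Rectangle

fig, ax = plt.subplots()

# Hypothetical stand-in for one geopandas polygon layer (assumed to be a PatchCollection).
layer = PatchCollection([Rectangle((0, 0), 1, 1)], facecolor="blue", label="USA")
ax.add_collection(layer)

# For comparison: fill_between returns a PolyCollection, which the legend supports.
filled = ax.fill_between([0, 1], [0, 1], label="filled area")

handler_map = Legend.get_default_handler_map()
for artist in (layer, filled):
    # Same lookup the legend performs: by the artist's type, then its base classes.
    handler = Legend.get_legend_handler(handler_map, artist)
    print(f"{type(artist).__name__} ({artist.get_label()!r}): handler = {handler}")
```

If the lookup returns None for the polygon layer, a bare ax.legend() has nothing it can draw for it even though the label is set, which is why explicit legend handles are needed.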


Solution

  • I have never used geopandas, but looking at the result it appears that those filled areas are PatchCollection objects, which the legend does not support. We can, however, create the legend artists ourselves:

    import geopandas as gpd
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    
    # Note: gpd.datasets was removed in geopandas 1.0, so this line needs an
    # older geopandas (or a local copy of the Natural Earth low-res file).
    world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
    usa = world[world.name == "United States of America"]
    china = world[world.name == "China"]
    
    fig, ax = plt.subplots()
    world.plot(ax=ax, color="whitesmoke", ec="black")
    usa.plot(ax=ax, color="blue", label="USA")
    china.plot(ax=ax, color="red", label="China")
    
    # ax.collections[0] is the world layer; skip it and build one square
    # marker handle per remaining layer, copying that layer's fill color.
    lines = [
        Line2D([0], [0], linestyle="none", marker="s", markersize=10,
               markerfacecolor=t.get_facecolor()[0])
        for t in ax.collections[1:]
    ]
    labels = [t.get_label() for t in ax.collections[1:]]
    ax.legend(lines, labels)
    plt.show()
    

    [image: the resulting map with a legend for the USA and China]
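A variant of the same idea uses matplotlib.patches.Patch proxy artists, the technique matplotlib's legend guide describes for artists the legend cannot draw; they render as filled rectangles with an outline, which look closer to the map polygons than square markers. To keep the sketch self-contained it does not load geopandas: it assumes each geopandas layer is a PatchCollection and uses two hypothetical rectangles in place of the real USA and China geometries.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Patch, Rectangle


def plot_with_proxy_legend():
    fig, ax = plt.subplots()

    # Hypothetical stand-ins for usa.plot(...) and china.plot(...):
    # one PatchCollection per layer, as geopandas is assumed to create.
    layers = {
        "USA": PatchCollection([Rectangle((0, 0), 4, 2)],
                               facecolor="blue", edgecolor="black"),
        "China": PatchCollection([Rectangle((6, 0), 3, 3)],
                                 facecolor="red", edgecolor="black"),
    }
    for collection in layers.values():
        ax.add_collection(collection)
    ax.autoscale_view()

    # One Patch proxy per layer, copying the layer's first face color.
    handles = [
        Patch(facecolor=collection.get_facecolor()[0], edgecolor="black", label=name)
        for name, collection in layers.items()
    ]
    ax.legend(handles=handles)
    return fig, ax


fig, ax = plot_with_proxy_legend()
```

With the real data, the same ax.legend(handles=...) call would go after usa.plot(...) and china.plot(...). Hard-coding the colors instead, e.g. Patch(facecolor="blue", label="USA"), also works and does not depend on the order of ax.collections.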