Search code examples
Tags: python, plotly, subplot, legend-properties

Add a color legend (colorbar) to map subplots


For this example, I modified the dataset and code from https://plotly.com/python/map-subplots-and-small-multiples/, adding a column whose values are used to color the markers.

What I want to do, both here and with my own dataset, is add a legend that shows the color scale.

Here the range is the same (0-9) across all maps, so either one shared legend or a separate legend for each subplot would work. However, the shared legend produced by the code below is wrong.

Related: Plotly contour subplots each having their own colorbar

import math

import numpy as np
import pandas as pd
import plotly.graph_objects as go

df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/1962_2006_walmart_store_openings.csv')

# New column: a random integer in 0-9 for every store, used as the marker color.
df["counts"] = np.random.choice(range(0, 10), df.shape[0])

years = df['YEAR'].unique()

# Subplot grid: one map per year, COLS maps per row, filled from the top row down.
# ROWS is derived from the number of years (instead of a hard-coded row count plus
# a magic stop index), so every year always gets its own map cell.
COLS = 5
ROWS = math.ceil(len(years) / COLS)

data = []
layout = dict(
    title=(
        'New Walmart Stores per year 1962-2006<br>'
        'Source: <a href="http://www.econ.umn.edu/~holmes/data/WalMart/index.html">'
        'University of Minnesota</a>'
    ),
    autosize=False,
    width=800,
    height=900,
    hovermode=False,
    legend=dict(
        x=0.7,
        y=-0.1,
        bgcolor="rgba(255, 255, 255, 0)",
        font=dict(size=11),
    ),
)

for i, year in enumerate(years):
    # Layout keys: 'geo' for the first map, then 'geo2', 'geo3', ...
    geo_key = 'geo' + str(i + 1) if i != 0 else 'geo'
    # Grid cell of this map; row 0 is the top row, so flip it for the y domain.
    row, col = divmod(i, COLS)
    y = ROWS - 1 - row
    stores = df[df['YEAR'] == year]  # filter once per year, not once per column

    # Walmart store data
    data.append(
        dict(
            type='scattergeo',
            showlegend=False,
            lon=list(stores['LON']),
            lat=list(stores['LAT']),
            geo=geo_key,
            name=int(year),
            marker=dict(
                color=list(stores['counts']),  # new: color by the counts column
                opacity=0.5,
                # new: every trace requests a colorbar, but none sets a colorbar
                # position or cmin/cmax, so all colorbars are drawn on top of each
                # other at the same default spot -- the wrong legend this question
                # is about.
                showscale=True,
            ),
        )
    )
    # Year label, drawn as text at a fixed lon/lat on each map
    data.append(
        dict(
            type='scattergeo',
            showlegend=False,
            lon=[-78],
            lat=[47],
            geo=geo_key,
            text=[year],
            mode='text',
        )
    )
    layout[geo_key] = dict(
        scope='usa',
        showland=True,
        landcolor='rgb(229, 229, 229)',
        showcountries=False,
        domain=dict(
            x=[col / COLS, (col + 1) / COLS],
            y=[y / ROWS, (y + 1) / ROWS],
        ),
        subunitcolor="rgb(255, 255, 255)",
    )

fig = go.Figure(data=data, layout=layout)
config = {'staticPlot': True}
fig.show(config=config)

[Screenshot of the resulting figure — image not available]


Solution

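  • I had to compute a position for each subplot's colorbar, so that each colorbar sits beside its own map instead of all of them stacking at the same default position: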

    import math

    import numpy as np
    import pandas as pd
    import plotly.graph_objects as go

    df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/1962_2006_walmart_store_openings.csv')

    # New column: a random integer in 0-9 for every store, used as the marker color.
    df["counts"] = np.random.choice(range(0, 10), df.shape[0])

    # One shared color range for all subplots. Without explicit cmin/cmax each
    # trace's colorbar is scaled to that year's values alone (e.g. a year with a
    # single store gets a one-value scale), so colors would not be comparable
    # between maps.
    cmin = int(df["counts"].min())
    cmax = int(df["counts"].max())

    years = df['YEAR'].unique()

    # Subplot grid: one map per year, COLS maps per row, filled from the top row
    # down. ROWS is derived from the number of years instead of being hard-coded
    # together with a magic stop index.
    COLS = 5
    ROWS = math.ceil(len(years) / COLS)

    data = []
    layout = dict(
        title=(
            'New Walmart Stores per year 1962-2006<br>'
            'Source: <a href="http://www.econ.umn.edu/~holmes/data/WalMart/index.html">'
            'University of Minnesota</a>'
        ),
        autosize=False,
        width=800,
        height=900,
        hovermode=False,
        legend=dict(
            x=0.7,
            y=-0.1,
            bgcolor="rgba(255, 255, 255, 0)",
            font=dict(size=11),
        ),
    )

    for i, year in enumerate(years):
        # Layout keys: 'geo' for the first map, then 'geo2', 'geo3', ...
        geo_key = 'geo' + str(i + 1) if i != 0 else 'geo'
        # Grid cell of this map; row 0 is the top row, so flip it for the y domain.
        row, col = divmod(i, COLS)
        y = ROWS - 1 - row
        stores = df[df['YEAR'] == year]  # filter once per year, not once per column

        # Walmart store data, colored by counts on the shared range
        data.append(
            dict(
                type='scattergeo',
                showlegend=False,
                lon=list(stores['LON']),
                lat=list(stores['LAT']),
                geo=geo_key,
                name=int(year),
                marker=dict(
                    color=list(stores['counts']),
                    cmin=cmin,
                    cmax=cmax,
                    opacity=0.5,
                    showscale=True,
                    # Put each colorbar near the top-right corner of its own map
                    # cell instead of the default position shared by all traces.
                    colorbar=dict(
                        len=0.1,
                        x=(col + 1) / COLS - 0.02,
                        y=(y + 1) / ROWS - 0.05,
                        thickness=5,
                    ),
                ),
            )
        )
        # Year label, drawn as text at a fixed lon/lat on each map
        data.append(
            dict(
                type='scattergeo',
                showlegend=False,
                lon=[-78],
                lat=[47],
                geo=geo_key,
                text=[year],
                mode='text',
            )
        )
        layout[geo_key] = dict(
            scope='usa',
            showland=True,
            landcolor='rgb(229, 229, 229)',
            showcountries=False,
            domain=dict(
                x=[col / COLS, (col + 1) / COLS],
                y=[y / ROWS, (y + 1) / ROWS],
            ),
            subunitcolor="rgb(255, 255, 255)",
        )

    fig = go.Figure(data=data, layout=layout)
    config = {'staticPlot': True}
    fig.show(config=config)