Combine entries in a single legend from Plotly subplots - python


The following code draws two scatter maps as Plotly subplots. I want to combine the points from both subplots into a single legend. However, if I plot the figure as is, categories that appear in both subplots show up twice in the legend. On the other hand, if I hide the legend for one subplot, not all entries are displayed.

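import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

df = pd.DataFrame({'Type' : ['1','1','1','1','1','2','2','2','2','2'],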
                   'Category' : ['A','D','D','D','F','B','D','A','D','E']
                  })

df['Color'] = df['Category'].map(dict(zip(df['Category'].unique(),
                    px.colors.qualitative.Dark24[:len(df['Category'].unique())])))

df = pd.concat([df]*10, ignore_index = True)

df['Lat'] = np.random.randint(0, 20, 100)
df['Lon'] = np.random.randint(0, 20, 100)

Color = df['Color'].unique()
Category = df['Category'].unique()

cats = dict(zip(Color, Category))

df_type_1 = df[df['Type'] == '1'].copy()
df_type_2 = df[df['Type'] == '2'].copy()

fig = make_subplots(
    rows = 1, 
    cols = 2,
    specs = [[{"type": "scattermapbox"}, {"type": "scattermapbox"}]],
    vertical_spacing = 0.05,
    horizontal_spacing = 0.05
    )

for c in df_type_1['Color'].unique():
    df_color = df_type_1[df_type_1['Color'] == c]
    fig.add_trace(go.Scattermapbox(
                    lat = df_color['Lat'],
                    lon = df_color['Lon'],
                    mode = 'markers',
                    name = cats[c],
                    marker = dict(color = c),
                    opacity = 0.8,
                    #legendgroup = 'group2',
                    #showlegend = True,
                    ),
          row = 1,
          col = 1
         )

for c in df_type_2['Color'].unique():
    df_color = df_type_2[df_type_2['Color'] == c]
    fig.add_trace(go.Scattermapbox(
                    lat = df_color['Lat'],
                    lon = df_color['Lon'],
                    mode = 'markers',
                    name = cats[c],
                    marker = dict(color = c),
                    opacity = 0.8,
                    #legendgroup = 'group2',
                    #showlegend = False,
                    ),
          row = 1,
          col = 2
         )    

fig.update_layout(height = 600, width = 800, margin = dict(l = 10, r = 10, t = 30, b = 10))

fig.update_layout(mapbox1 = dict(zoom = 2, style = 'carto-positron'),
                  mapbox2 = dict(zoom = 2, style = 'carto-positron'),
                  )

fig.show()

output: duplicate entries


If I set showlegend = False on either subplot, the legend does not show all applicable entries.

output: (subplot 2 showlegend = False)

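The reason no single subplot can carry the whole legend is that neither type contains every category. A small check in plain Python on the sample data above (the variable names are only for illustration):

```python
# Type and Category columns copied from the sample DataFrame above.
types = ['1', '1', '1', '1', '1', '2', '2', '2', '2', '2']
categories = ['A', 'D', 'D', 'D', 'F', 'B', 'D', 'A', 'D', 'E']

# Categories present in each subplot.
cats_1 = {c for t, c in zip(types, categories) if t == '1'}
cats_2 = {c for t, c in zip(types, categories) if t == '2'}

print(sorted(cats_1 & cats_2))  # in both subplots: duplicated legend entries
print(sorted(cats_2 - cats_1))  # missing if the legend of subplot 2 is hidden
print(sorted(cats_1 - cats_2))  # missing if the legend of subplot 1 is hidden
```

So hiding the legend of either subplot as a whole always drops at least one category.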


Solution

  • The simplest way I know at the moment to remove duplicate legend entries is to track the trace names in a set() and set showlegend=False on every trace whose name has already been seen. The snippet is adapted from another answer. I also changed the code to take the marker colors directly from the Color column, and restructured it so that all traces are added in a single nested loop, without creating an extra data frame per type.

    fig = make_subplots(
        rows = 1, 
        cols = 2,
        specs = [[{"type": "scattermapbox"}, {"type": "scattermapbox"}]],
        vertical_spacing = 0.05,
        horizontal_spacing = 0.05
        )
    
    for t in df['Type'].unique():
        dff = df.query('Type ==@t')
        for c in dff['Category'].unique():
            dffc = dff.query('Category == @c')
            fig.add_trace(go.Scattermapbox(
                    lat = dffc['Lat'],
                    lon = dffc['Lon'],
                    mode = 'markers',
                    name = c,
                    marker = dict(color = dffc['Color']),
                    opacity = 0.8,
                    ),
                row = 1,
                col = int(t)
                )
    fig.update_layout(height = 600, width = 800, margin = dict(l = 10, r = 10, t = 30, b = 10))
    fig.update_layout(mapbox1 = dict(zoom = 2, style = 'carto-positron'),
                      mapbox2 = dict(zoom = 2, style = 'carto-positron'),
                      )
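    # Hide the legend entry of every trace whose name was already seen,
    # so that each category appears only once in the legend.
    names = set()
    fig.for_each_trace(
        lambda trace:
            trace.update(showlegend=False)
            if (trace.name in names) else names.add(trace.name))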
    
    fig.show()
    

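One limitation of hiding duplicates this way: as far as I can tell, the remaining legend entry is tied only to the first trace with that name, so clicking it toggles the points in one subplot only. Below is a minimal sketch of a variant that also sets legendgroup to the trace name, so that traces sharing a name are shown and hidden together. The helper name dedupe_legend is made up for illustration. It relies only on the name, showlegend and legendgroup attributes that Plotly traces expose, and the SimpleNamespace objects are stand-ins so the snippet runs without Plotly.

```python
from types import SimpleNamespace

def dedupe_legend(traces):
    """Show one legend entry per trace name and group same-named traces."""
    seen = set()
    for trace in traces:
        # Traces in the same legend group are toggled together.
        trace.legendgroup = trace.name
        # Only the first trace with a given name keeps its legend entry.
        trace.showlegend = trace.name not in seen
        seen.add(trace.name)
    return traces

# Stand-in traces (hypothetical); the names mimic the two subplots.
demo = [SimpleNamespace(name=n, showlegend=True, legendgroup=None)
        for n in ['A', 'D', 'F', 'B', 'D', 'A', 'E']]
dedupe_legend(demo)
print([(t.name, t.showlegend) for t in demo])
```

With the figure from the answer, this would be called as dedupe_legend(fig.data) before fig.show(), in place of the for_each_trace call.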