Search code examples
Tags: python, plotly, legend, subplot

How to manage legend_tracegroupgap for different row_heights in Plotly subplots?


I found this example code for subplots with a legend beside each subplot. I changed it by adding row_heights, and now the legends no longer line up with their subplots.

import pandas as pd
import plotly.express as px

df = px.data.gapminder().query("continent=='Americas'")


from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Three stacked subplots with unequal relative heights.
fig = make_subplots(rows=3, cols=1, row_heights=[2, 1, 0.75])

# Countries plotted in each subplot row, in trace order. The row number is
# also used as the trace's legendgroup.
countries_by_row = {
    1: ['Canada', 'United States'],
    2: ['Mexico', 'Colombia', 'Brazil'],
    3: ['Argentina', 'Chile'],
}

for row, countries in countries_by_row.items():
    for country in countries:
        # Filter once per country instead of running the same query for x and y.
        country_df = df[df['country'] == country]
        # add_trace replaces the deprecated append_trace.
        fig.add_trace(go.Scatter(
            x=country_df['year'],
            y=country_df['lifeExp'],
            name=country,
            legendgroup=str(row),
        ), row=row, col=1)

fig.update_layout(
    height=800,
    width=800,
    title_text="Life Expectancy in the Americas",
    xaxis3_title='Year',
    yaxis1_title='Age',
    yaxis2_title='Age',
    yaxis3_title='Age',
    # One fixed gap (in px) between every pair of legend groups. With unequal
    # row_heights a single gap cannot line each group up with its subplot,
    # which is the problem described above.
    legend_tracegroupgap=100,
    yaxis1_range=[50, 90],
    yaxis2_range=[50, 90],
    yaxis3_range=[50, 90],
)
fig.show()

Now I am looking for a way to manage legend_tracegroupgap for different row_heights. I expect each legend to appear at the top of, and beside, its own subplot.


Solution

  • As of Plotly v5.15, you can add multiple legends and position each one relative to the height of the plot.

    In your example, add legend='legend', legend='legend2', or legend='legend3' to each go.Scatter so that it matches its legendgroup, then add legend={"y": 1.0}, legend2={"y": 0.42}, legend3={"y": 0.08} to fig.update_layout. These y values are in paper coordinates (0 is the bottom and 1 the top of the plotting area), so they must be adjusted if row_heights change.

    Below is the full code and resulting figure:

    import pandas as pd
    import plotly.express as px
    
    df = px.data.gapminder().query("continent=='Americas'")
    
    
    from plotly.subplots import make_subplots
    import plotly.graph_objects as go
    
    # Relative row heights. The legend positions below are derived from the
    # resulting subplot domains, so these can change without re-tuning legends.
    fig = make_subplots(rows=3, cols=1, row_heights=[2, 1, 0.75])
    
    # Countries plotted in each subplot row, in trace order. Each row gets its
    # own legend, and the row number is also used as the legendgroup.
    countries_by_row = {
        1: ['Canada', 'United States'],
        2: ['Mexico', 'Colombia', 'Brazil'],
        3: ['Argentina', 'Chile'],
    }
    
    def legend_name(row):
        """Return the layout legend id for a subplot row: 'legend', 'legend2', ..."""
        return 'legend' if row == 1 else f'legend{row}'
    
    for row, countries in countries_by_row.items():
        for country in countries:
            # Filter once per country instead of running the same query for x and y.
            country_df = df[df['country'] == country]
            # add_trace replaces the deprecated append_trace.
            fig.add_trace(go.Scatter(
                x=country_df['year'],
                y=country_df['lifeExp'],
                name=country,
                legend=legend_name(row),
                legendgroup=str(row),
            ), row=row, col=1)
    
    # Pin the top edge of each legend to the top of its subplot's y-axis domain
    # (paper coordinates) instead of hard-coding y values per row. yanchor='top'
    # is explicit because the default 'auto' anchor depends on the y value.
    legends = {
        legend_name(row): dict(
            y=fig.get_subplot(row, 1).yaxis.domain[1],
            yanchor='top',
        )
        for row in countries_by_row
    }
    
    fig.update_layout(
        height=800,
        width=800,
        title_text="Life Expectancy in the Americas",
        xaxis3_title='Year',
        yaxis1_title='Age',
        yaxis2_title='Age',
        yaxis3_title='Age',
        yaxis1_range=[50, 90],
        yaxis2_range=[50, 90],
        yaxis3_range=[50, 90],
        **legends
    )
    fig.show()
    

    enter image description here