I found this example code for subplots with a separate legend group beside each subplot. I changed it by adding `row_heights`, and now the legend groups no longer line up with their subplots.
import pandas as pd
import plotly.express as px
# Gapminder rows for the Americas; each trace below is a per-country slice.
df = px.data.gapminder().query("continent=='Americas'")
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Three stacked subplots whose relative heights are deliberately unequal.
fig = make_subplots(rows=3, cols=1, row_heights=[2, 1, 0.75])

# (country, subplot row) in plotting order. The row number, as a string,
# is also used as the legend group id so entries are grouped per subplot.
country_rows = (
    ("Canada", 1),
    ("United States", 1),
    ("Mexico", 2),
    ("Colombia", 2),
    ("Brazil", 2),
    ("Argentina", 3),
    ("Chile", 3),
)

for country_name, subplot_row in country_rows:
    country_df = df.query(f"country == '{country_name}'")
    fig.append_trace(
        go.Scatter(
            x=country_df["year"],
            y=country_df["lifeExp"],
            name=country_name,
            legendgroup=str(subplot_row),
        ),
        row=subplot_row,
        col=1,
    )

# Figure-level options. A single fixed gap separates the legend groups;
# this is the setting the question reports as no longer matching the
# subplots once row_heights is used.
layout_options = {
    "height": 800,
    "width": 800,
    "title_text": "Life Expectancy in the Americas",
    "xaxis3_title": "Year",
    "legend_tracegroupgap": 100,
}

# Every y-axis gets the same title and the same fixed range.
for axis_number in (1, 2, 3):
    layout_options[f"yaxis{axis_number}_title"] = "Age"
    layout_options[f"yaxis{axis_number}_range"] = [50, 90]

fig.update_layout(**layout_options)
fig.show()
Now I am looking for a way to manage `legend_tracegroupgap` for different `row_heights`. I expect each legend group to sit at the top, beside its subplot.
As of Plotly v5.15, you can add multiple legends and position them relative to the height of the plot.
In your example, you can add the argument `legend='legend'`, `legend='legend2'`, or `legend='legend3'` to each `go.Scatter` to match them with the `legendgroup`, then add the arguments `legend = {"y": 1.0}, legend2 = {"y": 0.42}, legend3 = {"y": 0.08}` to `fig.update_layout`.
Below is the full code and resulting figure:
import pandas as pd
import plotly.express as px
# Gapminder rows for the Americas; each trace below is a per-country slice.
df = px.data.gapminder().query("continent=='Americas'")
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Three stacked subplots whose relative heights are deliberately unequal.
fig = make_subplots(rows=3, cols=1, row_heights=[2, 1, 0.75])

# (country, subplot row, legend id) in plotting order. Each subplot's traces
# are assigned to their own legend; the row number, as a string, is also
# kept as the legend group id.
country_specs = (
    ("Canada", 1, "legend"),
    ("United States", 1, "legend"),
    ("Mexico", 2, "legend2"),
    ("Colombia", 2, "legend2"),
    ("Brazil", 2, "legend2"),
    ("Argentina", 3, "legend3"),
    ("Chile", 3, "legend3"),
)

for country_name, subplot_row, legend_id in country_specs:
    country_df = df.query(f"country == '{country_name}'")
    fig.append_trace(
        go.Scatter(
            x=country_df["year"],
            y=country_df["lifeExp"],
            name=country_name,
            legend=legend_id,
            legendgroup=str(subplot_row),
        ),
        row=subplot_row,
        col=1,
    )

# Figure-level options. Each legend gets its own vertical position; the
# y values are hand-picked for this particular row_heights setting.
layout_options = {
    "height": 800,
    "width": 800,
    "title_text": "Life Expectancy in the Americas",
    "xaxis3_title": "Year",
    "legend": {"y": 1.0},
    "legend2": {"y": 0.42},
    "legend3": {"y": 0.08},
}

# Every y-axis gets the same title and the same fixed range.
for axis_number in (1, 2, 3):
    layout_options[f"yaxis{axis_number}_title"] = "Age"
    layout_options[f"yaxis{axis_number}_range"] = [50, 90]

fig.update_layout(**layout_options)
fig.show()