
Plotly: Adding legend to subplot


I'm working with Python and Plotly, and I'm trying to add a separate legend to each of my subplots, but it isn't working.

When I run the following code:

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

fig = make_subplots(rows=2, cols=1)

y = np.arange(0,10,1)

fig.add_trace(go.Scatter(y=y,name="name1"), row=1,col=1)
fig.add_trace(go.Scatter(y=y**2,name="name2"), row=1,col=1)


fig.add_trace(go.Scatter(y=y,name="name3"), row=2,col=1)
fig.add_trace(go.Scatter(y=y**2,name="name4"), row=2,col=1)

fig.show()

This is what I get:

[Screenshot of the current output]

But I'd like to have something like this:

[Screenshot of the desired output]

Any idea?

Thanks!


Solution

  • When I looked into it, I could not find a way to give each subplot its own legend. The information I found suggested annotations as an alternative, so I added them to your code.

    import numpy as np
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    
    fig = make_subplots(rows=2, cols=1)
    
    y = np.arange(0, 10, 1)
    
    # legendgroup splits the single legend into one block per subplot
    fig.add_trace(go.Scatter(y=y, name="name1", legendgroup='1'), row=1, col=1)
    fig.add_trace(go.Scatter(y=y**2, name="name2", legendgroup='1'), row=1, col=1)
    
    fig.add_trace(go.Scatter(y=y, name="name3", legendgroup='2'), row=2, col=1)
    fig.add_trace(go.Scatter(y=y**2, name="name4", legendgroup='2'), row=2, col=1)
    
    # Label the right-hand end of each line: (y position, text) in paper coordinates
    labels = [(0.7, 'name 1'), (1.0, 'name 2'),   # top subplot
              (0.05, 'name 3'), (0.4, 'name 4')]  # bottom subplot
    for y_pos, text in labels:
        fig.add_annotation(x=1.0, y=y_pos, xref="paper", yref="paper",
                           text=text, showarrow=False)
    
    # Vertical gap (in px) between the legend groups
    fig.update_layout(legend_tracegroupgap=50)
    
    fig.show()
    

    [Screenshot of the resulting figure]

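    Update: I updated the code and figure to use legend groups (the legendgroup trace attribute), which split the single legend into one block per subplot; legend_tracegroupgap sets the vertical space between those blocks.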