
Plotly Custom Legend


I have a Plotly plot showing two price series, A and B, over time, with dashed purple vertical lines marking events.

The code I am using is below:

from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=pf['Timestamp'], y=pf['Price_A'], name='<b>A</b>',
                         mode='lines+markers',
                         marker_color='rgba(255, 0, 0, 0.8)',
                         line=dict(width=3), yaxis="y1"),
              secondary_y=False)

fig.add_trace(go.Scatter(x=df['Timestamp'], y=df['Price_B'], name='<b>B</b>',
                         mode='lines+markers',
                         marker_color='rgba(0, 196, 128, 0.8)',
                         line=dict(width=3), yaxis="y1"),
              secondary_y=False)

# One dashed vertical line per event timestamp
for i in pf2['Timestamp']:
    fig.add_vline(x=i, line_width=3, line_dash="dash", line_color="purple",
                  name='Event')

fig.update_layout(title="<b>Change over Time</b>",
                  font=dict(family="Courier New, monospace", size=16,
                            color="RebeccaPurple"),
                  legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01))

How can I add a legend entry for the events marked by the vertical lines?


Solution

  • When you use add_vline, you are adding a layout shape rather than a trace, and shapes do not get a legend entry by default.

    Instead, plot each vertical line as a go.Scatter trace: pass the same x value twice, and pass the minimum and maximum of your data (minus and plus some padding) as the two y values. Then set the plot's y-axis range to that same padded interval. The segments then look like full-height vertical lines while the full range of your data stays visible.

    Update: give all the line traces the same legendgroup and show the legend for only the first one, so the vertical lines appear as a single legend entry (clicking it toggles all of them together).

    For example:

    import plotly.express as px
    import plotly.graph_objects as go
    
    fig = go.Figure()
    
    df = px.data.stocks()
    for col in ['GOOG', 'AMZN']:
        fig.add_trace(go.Scatter(
            x=df['date'],
            y=df[col],
            name=col
        ))
    
    # Dates at which to draw the vertical "event" lines
    vlines = ["2018-07-01", "2019-04-01", "2019-07-01"]
    # Overall data range, padded by 5% so the lines extend past the data
    min_y, max_y = df[['GOOG', 'AMZN']].min().min(), df[['GOOG', 'AMZN']].max().max()
    padding = 0.05 * (max_y - min_y)
    for i, x in enumerate(vlines):
        fig.add_trace(go.Scatter(
            x=[x] * 2,
            y=[min_y - padding, max_y + padding],
            mode='lines',
            line=dict(color='purple', dash="dash"),
            name="vertical lines",
            legendgroup="vertical lines",  # all lines toggle together
            showlegend=(i == 0)             # only the first line gets a legend entry
        ))
    
    # Fix the y-range so the segments appear to span the full plot height
    fig.update_yaxes(range=[min_y - padding, max_y + padding])
    fig.show()
    
