Tags: python, plotly, plotly-python

Plotly - Renormalise data in response to zooming in on plotly figure


I have some data like:

import pandas as pd
import datetime as dt
import plotly.graph_objects as go

# Example data
returns = pd.DataFrame({'time': [dt.date(2020,1,1), dt.date(2020,1,2), dt.date(2020,1,3), dt.date(2020,1,4), dt.date(2020,1,5), dt.date(2020,1,6)],
                        'longs': [0, 1,2,3,4,3],
                        'shorts': [0, -1,-2,-3,-4,-4]})

and a chart like:

fig = go.Figure()
fig.add_trace(
    go.Scatter(x=list(returns.time),
               y=list(returns.longs),
               name="Longs",
               line=dict(color="#0c56cc")))
fig.add_trace(
    go.Scatter(x=list(returns.time),
               y=list(returns.shorts),
               name="Shorts",
               line=dict(color="#850411", dash="dash")))
fig.show()

Chart

However, when I zoom in on a particular section, I would like the data to be renormalised using the first date in the x-range. For example, zooming in on the period after Jan 5 here would give me something like:

Desired Outcome
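To make the intended arithmetic concrete, here is a minimal pandas sketch. It assumes that "renormalising" means subtracting each series' value on the first date inside the zoomed window, so both lines start at 0 there; `start` is a stand-in for the left edge of the zoom:

```python
import datetime as dt
import pandas as pd

# Same example data as above
returns = pd.DataFrame({
    'time': [dt.date(2020, 1, d) for d in range(1, 7)],
    'longs': [0, 1, 2, 3, 4, 3],
    'shorts': [0, -1, -2, -3, -4, -4],
})

# Assumption: the zoom window starts on Jan 5
start = dt.date(2020, 1, 5)

# Value of each series on the first date inside the window
base = returns.loc[returns['time'] >= start, ['longs', 'shorts']].iloc[0]

# Subtract that starting value from every row of the matching column
renormed = returns[['longs', 'shorts']] - base
renormed.insert(0, 'time', returns['time'])

print(renormed[renormed['time'] >= start])
```

Inside the window, longs become 0 and -1 while shorts become 0 and 0, which is the "start from zero at the left edge" behaviour described above.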

This seems to be possible using something like fig.layout.on_change, as discussed here. My attempt is below, but it does not appear to change the figure at all.

def renormalise_returns(returns, renorm_date):
    long_data = returns.melt(id_vars = ['time'])
    long_data.sort_values(by = 'time', inplace=True)
    cum_at_start = long_data[long_data['time'] >= renorm_date].groupby(['variable']).first().reset_index().rename(columns = {'value': 'old_value'}).drop(columns = ['time'])
    long_data2 = pd.merge(long_data, cum_at_start, how = 'left', on = ['variable'])
    long_data2['value'] = long_data2['value'] - long_data2['old_value']
    long_data2.drop(columns = ['old_value'], inplace=True)
    finn = long_data2.pivot(index = ['time'], columns = ['variable'], values = 'value').reset_index()
    return finn

fig = go.FigureWidget([go.Scatter(x=returns['time'], y=returns['longs'], name='Longs', line=dict(color="#0c56cc")),
                      go.Scatter(x=returns['time'], y=returns['shorts'], name="Shorts", line=dict(color="#850411", dash="dash"))])
def zoom(xrange):
    xrange_zoom_min, xrange_zoom_max = fig.layout.xaxis.range[0], fig.layout.xaxis.range[1]
    df2 = renormalise_returns(returns, xrange_zoom_min)
    fig = go.FigureWidget([go.Scatter(x=df2['time'], y=df2['longs']), go.Scatter(x=df2['time'], y=df2['shorts'])])

fig.layout.on_change(zoom, 'xaxis.range')
fig.show()

Solution

    First, you need to convert returns['time'] to strings to prevent a JSON serialization error with the datetime.date values (it seems that go.Figure() handles these, but go.FigureWidget() does not).
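The serialization point can be reproduced with the standard-library json module alone (this illustrates json's behaviour, not Plotly's own serializer). The sketch also shows why the string conversion does not break renormalise_returns: assuming the zoom callback reports the left edge as a string such as '2020-01-04 12:00:00', ISO date strings compare lexicographically in date order:

```python
import datetime as dt
import json
import pandas as pd

times = pd.Series([dt.date(2020, 1, d) for d in range(1, 7)])

# json cannot serialise datetime.date objects directly...
try:
    json.dumps(list(times))
    serialisable = True
except TypeError:
    serialisable = False

# ...but the ISO strings produced by astype(str) serialise fine
as_str = times.astype(str)
encoded = json.dumps(list(as_str))

# Assumption: a zoom reports its left edge as a timestamp string like this one.
# '2020-01-04' is a prefix of it, so it sorts lower; later days sort higher.
left_edge = '2020-01-04 12:00:00'
inside = as_str[as_str >= left_edge]

print(serialisable, list(inside))
```

So a plain >= on the string column keeps the days that fall after the left edge, which is the comparison renormalise_returns performs.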

    Second, when you register fig.layout.on_change(zoom, 'xaxis.range'), the callback zoom is called with two arguments: the layout object first, and then the new value of the observed xaxis.range. The zoom in the question accepts only one parameter, so that call fails.
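The calling convention can be illustrated without a notebook. The sketch below uses a hypothetical FakeLayout class (not part of Plotly) that mimics only the argument order described above, and shows why a one-parameter zoom raises TypeError:

```python
# Hypothetical stand-in for a Plotly layout: it only mimics the calling
# convention (the object first, then the watched value).
class FakeLayout:
    def __init__(self):
        self._callbacks = []

    def on_change(self, callback, *paths):
        # The fake ignores which property paths are watched
        self._callbacks.append(callback)

    def set_xaxis_range(self, new_range):
        for callback in self._callbacks:
            callback(self, new_range)  # layout first, then the new range

received = []

def zoom(layout, xrange):       # two parameters: matches the convention
    received.append(xrange[0])

def zoom_one_arg(xrange):       # one parameter, as in the question
    received.append(xrange[0])

ok = FakeLayout()
ok.on_change(zoom, 'xaxis.range')
ok.set_xaxis_range(('2020-01-04 12:00:00', '2020-01-06'))

bad = FakeLayout()
bad.on_change(zoom_one_arg, 'xaxis.range')
try:
    bad.set_xaxis_range(('2020-01-04 12:00:00', '2020-01-06'))
    one_arg_failed = False
except TypeError:
    # zoom_one_arg() takes 1 positional argument but 2 were given
    one_arg_failed = True

print(received, one_arg_failed)
```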

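    Also, the callback must update the existing figure in place, using a method such as update_traces(). Re-assigning a new figure to the fig variable loses the reference to the figure that is actually displayed, so the change is never reflected in the plot. (Because fig is assigned inside zoom, Python also treats fig as a local variable throughout that function, so the earlier read of fig.layout.xaxis.range would raise UnboundLocalError.)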
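The difference between rebinding and updating in place is plain Python name binding, so it can be shown with a hypothetical Widget class standing in for the displayed FigureWidget (the class and function names are illustrative only):

```python
class Widget:
    """Hypothetical stand-in for a displayed FigureWidget."""
    def __init__(self, data):
        self.data = list(data)

fig = Widget([1, 2, 3])
displayed = fig  # the object the notebook is actually showing

def zoom_rebind(xrange):
    # Assigning to fig below makes fig local to this function,
    # so this first read raises UnboundLocalError.
    start = fig.data[0]
    fig = Widget([v - start for v in fig.data])

def zoom_rebind_global(xrange):
    # Even with `global`, rebinding creates a new object;
    # `displayed` still points at the old one.
    global fig
    start = fig.data[0]
    fig = Widget([v - start for v in fig.data])

def zoom_update(xrange):
    # Mutating the existing object (like fig.update_traces) keeps
    # the displayed reference in sync.
    start = fig.data[0]
    fig.data[:] = [v - start for v in fig.data]

try:
    zoom_rebind(None)
    rebind_failed = False
except UnboundLocalError:
    rebind_failed = True

zoom_rebind_global(None)
stale = list(displayed.data)   # unchanged: [1, 2, 3]

fig = displayed                # point fig back at the displayed object
zoom_update(None)              # displayed.data is now [0, 1, 2]
print(rebind_failed, stale, displayed.data)
```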

    One last thing (cf. Displaying Figures Using ipywidgets):

    It is important to note that FigureWidget is not meant to use the renderers framework [...], so you should not use the show() figure method or the plotly.io.show() function on FigureWidget objects.

    In fact, a FigureWidget is also an ipywidgets object. ipywidgets have their own display repr, which lets IPython's display framework render them. So ending the cell with fig alone, or calling display(fig), renders the figure just as fig.show() would, but with the ipywidgets context (and therefore the on_change callback) enabled.

    import pandas as pd
    import datetime as dt
    import plotly.graph_objects as go
    from IPython.display import display
    
    returns = pd.DataFrame({
        'time': [dt.date(2020, 1, 1), dt.date(2020, 1, 2), dt.date(2020, 1, 3), dt.date(2020, 1, 4), dt.date(2020, 1, 5), dt.date(2020, 1, 6)],
        'longs': [0, 1, 2, 3, 4, 3],
        'shorts': [0, -1, -2, -3, -4, -4]
    })
    
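    # datetime.date values cause a JSON serialization error with FigureWidget;
    # ISO strings such as '2020-01-01' do not, and they still compare in date order
    returns['time'] = returns['time'].astype(str)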
    
    def renormalise_returns(returns, renorm_date):
        long_data = returns.melt(id_vars = ['time'])
        long_data.sort_values(by = 'time', inplace=True)
        cum_at_start = long_data[long_data['time'] >= renorm_date].groupby(['variable']).first().reset_index().rename(columns = {'value': 'old_value'}).drop(columns = ['time'])
        long_data2 = pd.merge(long_data, cum_at_start, how = 'left', on = ['variable'])
        long_data2['value'] = long_data2['value'] - long_data2['old_value']
        long_data2.drop(columns = ['old_value'], inplace=True)
        finn = long_data2.pivot(index = ['time'], columns = ['variable'], values = 'value').reset_index()
        return finn
    
    fig = go.FigureWidget([
        go.Scatter(x=returns['time'], y=returns['longs'], name='Longs', line=dict(color="#0c56cc")),
        go.Scatter(x=returns['time'], y=returns['shorts'], name="Shorts", line=dict(color="#850411", dash="dash"))
    ])
    
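    def zoom(layout, xrange):
        # on_change passes the layout object first, then the new xaxis.range value
        xrange_zoom_min, _ = xrange  # only the left edge of the zoom is needed
        df2 = renormalise_returns(returns, xrange_zoom_min)
        # Update the displayed traces in place rather than building a new figure
        fig.update_traces(selector=dict(name='Longs'), x=df2['time'], y=df2['longs'])
        fig.update_traces(selector=dict(name='Shorts'), x=df2['time'], y=df2['shorts'])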
    
    fig.layout.on_change(zoom, 'xaxis.range')
    
    display(fig)
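As a possible simplification (a swapped-in approach, not what the answer above uses), the same renormalisation can be done on the wide DataFrame directly, without melt and pivot. renormalise_wide is a hypothetical helper; it assumes the frame is already sorted by time and that times are ISO strings, as after the astype(str) conversion:

```python
import pandas as pd

def renormalise_wide(df, renorm_date, time_col='time'):
    """Subtract each value column's entry at the first row whose time is
    >= renorm_date, so every series starts at 0 at the left edge of the zoom.
    Hypothetical helper; assumes df is sorted by time_col."""
    value_cols = [c for c in df.columns if c != time_col]
    inside = df[df[time_col] >= renorm_date]
    if inside.empty:
        # Zoom starts after the last date: return the data unchanged
        return df.copy()
    base = inside[value_cols].iloc[0]
    out = df.copy()
    out[value_cols] = df[value_cols] - base  # aligned column by column
    return out

# Example data from the question, with time already converted to strings
returns = pd.DataFrame({
    'time': ['2020-01-01', '2020-01-02', '2020-01-03',
             '2020-01-04', '2020-01-05', '2020-01-06'],
    'longs': [0, 1, 2, 3, 4, 3],
    'shorts': [0, -1, -2, -3, -4, -4],
})

df2 = renormalise_wide(returns, '2020-01-04 12:00:00')
print(df2)
```

Inside zoom, df2 = renormalise_wide(returns, xrange_zoom_min) could replace the renormalise_returns call. The empty-window guard matters when the zoom starts after the last date: there the melt/merge version finds no starting value, and the left merge fills old_value with NaN.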