Search code examples
pythonplotlyplotly-python

How to dynamically change the scale/ticks of y axis in plotly charts upon zooming in?


I am trying to make a candlestick chart using plotly. I am using stock data spanning over 10 years. Because of this, the candles appear very small, as the y axis has a large scale. However, if I zoom into a smaller time period (let's say any one month in the 10 years), I want the y axis scale to change so that the candles look big. Below is my code:

# Fetch TSLA quotes for 2011-11-04 .. 2021-11-04 into a DataFrame.
df_stockData = pdr.DataReader('TSLA', data_source='yahoo', start='2011-11-04', end='2021-11-04')

# Two stacked subplots sharing one x axis: row 1 holds the candles, row 2 the RSI line.
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, row_width=[0.25, 0.75])
# Row 1: OHLC candlesticks, increasing candles green and decreasing candles red.
fig.add_trace(go.Candlestick(
    x=df_stockData.index,
    open=df_stockData['Open'],
    high=df_stockData['High'],
    low=df_stockData['Low'],
    close=df_stockData['Close'],
    increasing_line_color='green',
    decreasing_line_color='red',
    showlegend=False
), row=1, col=1)
# Row 2: the RSI_14 series.
# NOTE(review): 'RSI_14' is not created anywhere in this snippet — presumably
# it is computed elsewhere before this code runs; confirm.
fig.add_trace(go.Scatter(
    x=df_stockData.index,
    y=df_stockData['RSI_14'],
    line=dict(color='#ff9900', width=2),
    showlegend=False,
), row=2, col=1
)
# No y-range is set anywhere above, so the y axes are left to plotly's defaults.
fig.show()

My chart looks as follows: enter image description here

As you can see the y-axis (stock price) has a very large scale. Even if I zoom in to a smaller time period the y axis scale remains the same. Is there any way to make the y-axis scale change dynamically so that the candles appear bigger when I zoom in?

enter image description here


Solution

  • This approach uses a Dash callback to set the range of the y-axis based on the x-axis range selected with the range slider.

    A significant amount of the code below only serves to turn your figure into a minimal working example (MWE), e.g. by calculating the RSI_14 column.

    import pandas_datareader as pdr
    from plotly.subplots import make_subplots
    import plotly.graph_objects as go
    import numpy as np
    import dash
    from dash.dependencies import Input, Output, State
    from jupyter_dash import JupyterDash
    
    
    # Download TSLA quotes for 2011-11-04 .. 2021-11-04; the RSI columns that
    # complete this dataframe are derived from it further down.
    # NOTE(review): depends on pandas_datareader's "yahoo" source being
    # reachable — confirm it still works in your environment.
    df_stockData = pdr.DataReader(
        "TSLA", data_source="yahoo", start="2011-11-04", end="2021-11-04"
    )
    
    
    def rma(x, n, y0):
        """Return Wilder's running moving average of ``x``, padded for alignment.

        The smoothing recurrence is ``y[i] = ((n - 1) * y[i - 1] + x[i]) / n``,
        started from the seed value ``y0``.

        The recurrence is evaluated directly rather than through the closed
        form ``cumsum(a**k * x) / a**k``: for long inputs ``a**k`` underflows
        to zero and the closed form degenerates into ``0 / 0 = NaN``.

        Args:
            x: 1-D sequence of samples that follow the seed window.
            n: Smoothing period; also the number of leading NaN pad values.
            y0: Seed average, emitted right after the NaN padding.

        Returns:
            A float array of length ``n + 1 + len(x)``: ``n`` NaNs, then
            ``y0``, then one smoothed value per element of ``x``.
        """
        decay = (n - 1) / n
        samples = np.asarray(x, dtype=float)

        # smoothed[0] is the seed; smoothed[i + 1] is the average after x[i].
        smoothed = np.empty(len(samples) + 1)
        smoothed[0] = y0
        for i, sample in enumerate(samples):
            smoothed[i + 1] = decay * smoothed[i] + sample / n

        return np.r_[np.full(n, np.nan), smoothed]
    
    
    # RSI look-back period.
    n = 14
    # Alias only: every column added below also lands on df_stockData.
    df = df_stockData

    # Move of the close relative to the previous row (first row is NaN).
    df["change"] = df["Close"].diff()
    # Gains keep the up-moves and zero out the down-moves; losses are the
    # mirror image, negated into positive magnitudes (-0.0 negates to 0.0).
    df["gain"] = df["change"].mask(df["change"] < 0, 0.0)
    df["loss"] = -(df["change"].mask(df["change"] > 0, -0.0))

    # Seed each average with the sum of the first n + 1 rows divided by n
    # (nansum skips the NaN in row 0), then let rma() smooth the remainder.
    gains = df["gain"].to_numpy()
    losses = df["loss"].to_numpy()
    df["avg_gain"] = rma(gains[n + 1 :], n, np.nansum(gains[: n + 1]) / n)
    df["avg_loss"] = rma(losses[n + 1 :], n, np.nansum(losses[: n + 1]) / n)

    # Relative strength and the resulting RSI column plotted in the lower panel.
    df["rs"] = df["avg_gain"] / df["avg_loss"]
    df["RSI_14"] = 100 - 100 / (1 + df["rs"])
    
    
    # Trace for the upper panel: OHLC candles, increasing candles green and
    # decreasing candles red.
    price_candles = go.Candlestick(
        x=df_stockData.index,
        open=df_stockData["Open"],
        high=df_stockData["High"],
        low=df_stockData["Low"],
        close=df_stockData["Close"],
        increasing_line_color="green",
        decreasing_line_color="red",
        showlegend=False,
    )
    # Trace for the lower panel: the RSI_14 series as an orange line.
    rsi_curve = go.Scatter(
        x=df_stockData.index,
        y=df_stockData["RSI_14"],
        line={"color": "#ff9900", "width": 2},
        showlegend=False,
    )

    # Two stacked subplots sharing one x axis; candles in row 1, RSI in row 2.
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, row_width=[0.25, 0.75])
    fig.add_trace(price_candles, row=1, col=1)
    fig.add_trace(rsi_curve, row=2, col=1)
    
    
    # Build App: one Graph component (id "fig") wrapped in a Div; the callback
    # below both listens to and updates this component.
    price_graph = dash.dcc.Graph(id="fig", figure=fig)
    app = JupyterDash(__name__)
    app.layout = dash.html.Div([price_graph])
    
    
    def _requested_x_range(relayout):
        """Return the ``[start, end]`` x-range carried by a relayout event, or None.

        The range is accepted in two shapes: one list under ``"xaxis.range"``,
        or the two scalar keys ``"xaxis.range[0]"`` and ``"xaxis.range[1]"``.
        """
        if not relayout:
            return None
        if "xaxis.range" in relayout:
            return list(relayout["xaxis.range"])
        if "xaxis.range[0]" in relayout and "xaxis.range[1]" in relayout:
            return [relayout["xaxis.range[0]"], relayout["xaxis.range[1]"]]
        return None


    @app.callback(
        Output("fig", "figure"),
        Input("fig", "relayoutData"),
    )
    def scaleYaxis(rng):
        """Fit the price y-axis to the candles inside the selected x-range.

        Args:
            rng: The graph's ``relayoutData`` (a dict, or None before any
                interaction).

        Returns:
            The module-level ``fig``, updated in place, so the chosen ranges
            persist across callback invocations.

        Handled events:
            * ``"xaxis.autorange"`` truthy (axes reset): both explicit ranges
              are cleared and autoranging is re-enabled, so a previously
              computed y-range does not stay pinned.
            * a new x-range in either shape accepted by
              ``_requested_x_range``: the y-range becomes the min/max of the
              OHLC columns inside it.
            * anything else: the figure is returned unchanged.
        """
        layout = fig["layout"]

        if rng and rng.get("xaxis.autorange"):
            layout["xaxis"]["range"] = None
            layout["xaxis"]["autorange"] = True
            layout["yaxis"]["range"] = None
            layout["yaxis"]["autorange"] = True
            return fig

        x_range = _requested_x_range(rng)
        if x_range is None:
            return fig

        try:
            window = df_stockData.loc[
                x_range[0] : x_range[1],
                ["High", "Low", "Open", "Close"],
            ]
        except KeyError:
            # Best effort: if the bounds cannot be resolved against the index,
            # keep the current y-range and still honour the requested x-range.
            window = None

        if window is not None and len(window) > 0:
            layout["yaxis"]["range"] = [window.min().min(), window.max().max()]
            # An explicit range only applies while autorange is off, which
            # matters after a previous reset switched it on.
            layout["yaxis"]["autorange"] = False

        layout["xaxis"]["range"] = x_range
        layout["xaxis"]["autorange"] = False
        return fig
    
    
    # Start the Dash server for the app built above.
    # NOTE(review): mode="inline" presumably renders the app inside the
    # notebook output cell (JupyterDash) — confirm against the installed version.
    app.run_server(mode="inline")
    

    enter image description here