Search code examples
pandas plotly data-visualization data-science plotly-python

How to add multiple Moving Average lines DYNAMICALLY to Plotly Candlestick


I have written this function to plot a candlestick chart using Plotly. I want to be able to add N different moving-average lines dynamically, each with a different color. Right now it adds just one line. I could hard-code another Scatter trace, but how can I add the lines dynamically?

Here's the code:

def plot_candlesticks(df, names = ('DATE','OPEN','CLOSE','LOW','HIGH'), mv = 44):
    '''
    Show a candlestick chart of a price DataFrame with one moving-average line.

    args:
        df: DataFrame holding the columns listed in ``names`` plus a 'SYMBOL'
            column; the function works on a copy of it.
        names: Column names, unpacked in the order (date, open, close, low, high).
        mv: Window handed to ``rolling`` for the moving average of the close column.

    The figure is displayed with ``show()``; nothing is returned.
    '''
    date_col, open_col, close_col, low_col, high_col = names
    sma_col = f'{str(mv)}-SMA'

    # Private copy, reordered by descending index label before the rolling mean.
    frame = df.copy()
    frame.sort_index(ascending=False, inplace=True)
    frame[sma_col] = frame[close_col].rolling(mv, min_periods=1).mean()

    price_trace = go.Candlestick(x=frame[date_col], name='Trade',
                                 open=frame[open_col],
                                 high=frame[high_col],
                                 low=frame[low_col],
                                 close=frame[close_col])
    sma_trace = go.Scatter(name=f'{str(mv)} MA', x=frame[date_col], y=frame[sma_col],
                           line={'color': 'blue', 'width': 1})
    fig = go.Figure(data=[price_trace, sma_trace])

    # Quick-zoom buttons shown above the x axis.
    range_buttons = [
        {'count': 1, 'label': '1M', 'step': 'month', 'stepmode': 'backward'},
        {'count': 6, 'label': '6M', 'step': 'month', 'stepmode': 'backward'},
        {'count': 1, 'label': 'YTD', 'step': 'year', 'stepmode': 'todate'},
        {'count': 1, 'label': '1Y', 'step': 'year', 'stepmode': 'backward'},
        {'step': 'all'},
    ]
    fig.update_xaxes(title_text='Date',
                     rangeslider_visible=True,
                     rangeselector={'buttons': range_buttons})

    # The title is looked up in `all_stocks`, which is defined outside this
    # function, keyed by the frame's 'SYMBOL' value at label 0.
    fig.update_layout(
        autosize=True,
        title={'text': all_stocks[frame['SYMBOL'][0]], 'y': 0.97, 'x': 0.5,
               'xanchor': 'center', 'yanchor': 'top'},
        margin={'l': 30, 'r': 30, 'b': 30, 't': 30, 'pad': 2},
        paper_bgcolor="lightsteelblue",
    )

    fig.update_yaxes(title_text='Close Price', tickprefix=u"\u20B9")  # Rupee symbol
    fig.show()

Solution

  • I have modified the code on the understanding that you want to add multiple moving averages to the candlestick chart automatically, with the window sizes passed as a list in the function argument. This is achieved by calling add_trace() on the candlestick figure once for each window.

    import plotly.graph_objects as go
    import pandas as pd
    import yfinance as yf
    
    # Download daily AAPL prices and keep only the OHLC columns. They are
    # selected by name so the result does not depend on the column order
    # that yfinance happens to return.
    data = yf.download("AAPL", start="2021-01-01", end="2021-03-01")
    data = data[['Open', 'High', 'Low', 'Close']]
    # Move the date index into an ordinary column so it can be used as the x axis.
    data = data.reset_index()
    
    def plot_candlesticks(df, names = ('Date','Open','High','Low','Close'), mv = (5,25,75)):
        '''
        Plot a candlestick chart of ``df`` with one moving-average line per window.

        args:
            df: DataFrame holding the price data; it is copied, not modified.
            names: Tuple of column names in the order (date, open, high, low, close).
            mv: Moving-average window, or an iterable of windows. One line is
                drawn per window, computed on the close column.

        The figure is displayed with ``show()``; nothing is returned.
        '''
        stocks = df.copy()
        # Unpack in the same order as the default ``names`` tuple so that the
        # high and close columns are not swapped.
        Date, Open, High, Low, Close = names

        # NOTE(review): rolling() works in row order, so after this descending
        # sort each window covers the rows with higher index labels. Confirm
        # that this matches chronological order for the data being passed in.
        stocks.sort_index(ascending=False, inplace = True)

        # Accept either a single window (e.g. 44) or an iterable of windows.
        try:
            windows = list(mv)
        except TypeError:
            windows = [mv]

        # Palette is reused cyclically when there are more windows than colors.
        colors = ['red', 'blue', 'yellow']

        candle = go.Figure(data = [go.Candlestick(x = stocks[Date], name = 'Trade',
                                                  open = stocks[Open],
                                                  high = stocks[High],
                                                  low = stocks[Low],
                                                  close = stocks[Close]),])
        for i, window in enumerate(windows):
            sma = stocks[Close].rolling(window, min_periods = 1).mean()
            candle.add_trace(go.Scatter(name=f'{str(window)} MA', x=stocks[Date], y=sma,
                                        line=dict(color=colors[i % len(colors)], width=2)))

        candle.update_xaxes(
            title_text = 'Date',
            rangeslider_visible = True,
            rangeselector = dict(
                buttons = [
                    dict(count = 1, label = '1M', step = 'month', stepmode = 'backward'),
                    dict(count = 6, label = '6M', step = 'month', stepmode = 'backward'),
                    dict(count = 1, label = 'YTD', step = 'year', stepmode = 'todate'),
                    dict(count = 1, label = '1Y', step = 'year', stepmode = 'backward'),
                    dict(step = 'all')]))

        # NOTE(review): the title text is a literal placeholder string, not a
        # lookup; replace it with a real title if one is available.
        candle.update_layout(autosize = True,
                             title = {'text': "all_stocks[stocks['SYMBOL'][0]]",'y':0.97,'x':0.5,
                                      'xanchor': 'center','yanchor': 'top'},
                             margin=dict(l=30,r=30,b=30,t=30,pad=2),
                             paper_bgcolor="lightsteelblue",)

        candle.update_yaxes(title_text = 'Close Price', tickprefix = u"\u20B9" ) # Rupee symbol
        candle.show()
    
    # Draw the chart for the downloaded data with the default column names and windows.
    plot_candlesticks(data)
    

    enter image description here