
How to get legend to show in graph using matplotlib


I made a simple program that computes exponential moving averages (EMAs) for a stock. Code below:

import yfinance as yf
import pandas_datareader as pdr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.style as style
import datetime as dt

yf.pdr_override()

style.use('ggplot')

startyear = 2019
startmonth = 1
startday = 1

start = dt.datetime(startyear, startmonth, startday)
end = dt.datetime.now()

stock = input('Enter stock ticker: ')

df = pdr.get_data_yahoo(stock, start, end)

emasUsed = [3, 5, 8, 10, 13, 15, 30, 35, 40, 45, 50, 60]

for x in emasUsed:
    ema = x
    df['EMA_'+str(ema)] = df['Adj Close'].ewm(span=ema, adjust=True).mean()
    df['EMA_'+str(ema)].plot()

plt.show()

I want to plot the moving averages, but the legend only shows up if I plot the EMA columns together in a separate call like this:

df[['EMA_3', 'EMA_5', 'EMA_8', etc...]].plot()

That is a lot of work, especially if I later want to add or change the EMAs.

Is there any way to get the legend to show up without having to type in each EMA manually?

Thanks, Dan


Solution

  • Create the Axes before plotting, then call its legend() method after all the lines have been drawn. Each Series.plot() call labels its line with the Series name, so the legend picks up the EMA_ names automatically.

    import yfinance as yf
    import pandas_datareader as pdr
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.style as style
    import datetime as dt
    
    yf.pdr_override()
    
    style.use('ggplot')
    
    startyear = 2019
    startmonth = 1
    startday = 1
    
    start = dt.datetime(startyear, startmonth, startday)
    end = dt.datetime.now()
    
    #stock = input('Enter stock ticker: ')
    stock = 'SPY'
    
    df = pdr.get_data_yahoo(stock, start, end)
    
    emasUsed = [3, 5, 8, 10, 13, 15, 30, 35, 40, 45, 50, 60]
    
    fig, ax = plt.subplots(figsize=(10, 8)) # get the axis and additionally set a bigger plot size
    
    for ema in emasUsed:
        df['EMA_'+str(ema)] = df['Adj Close'].ewm(span=ema, adjust=True).mean()
        df['EMA_'+str(ema)].plot(ax=ax) # draw explicitly on the axis created above
    legend = ax.legend(loc='upper left') # add the legend once all lines are plotted
    
    plt.show()
    

    And the result: [image: the EMA plot with its legend in the upper left]
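
Another option, closer to what the question asks for, is to build the list of column names from the EMA spans and plot all of them in a single DataFrame.plot() call, which shows a legend by default. The sketch below is a minimal illustration and not part of the original answer: it uses synthetic prices in place of the Yahoo download, selects the Agg backend so it runs without a display, and the names emas_used and ema_cols are made up for the example.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display needed for this sketch)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic random-walk prices stand in for the downloaded data (not real quotes).
rng = np.random.default_rng(0)
dates = pd.date_range("2019-01-01", periods=250, freq="B")
df = pd.DataFrame({"Adj Close": 100 + rng.normal(0, 1, len(dates)).cumsum()},
                  index=dates)

emas_used = [3, 5, 8, 10, 13, 15, 30, 35, 40, 45, 50, 60]

# Derive the column names from the list, so editing the list is the only change needed.
ema_cols = [f"EMA_{span}" for span in emas_used]
for span, col in zip(emas_used, ema_cols):
    df[col] = df["Adj Close"].ewm(span=span, adjust=True).mean()

# DataFrame.plot draws one line per column and adds a legend by default.
ax = df[ema_cols].plot(figsize=(10, 8))
ax.legend(loc="upper left")  # optional: only repositions the legend

# Collect the legend entries to confirm they match the generated column names.
legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
```

With real data, replace the synthetic DataFrame with the downloaded one, drop the matplotlib.use("Agg") line, and finish with plt.show(). Adding or removing an EMA then only means editing emas_used.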