
Shading regions inside an mplfinance chart


I am using matplotlib v 3.7.0, mplfinance version '0.12.9b7', and Python 3.10.

I am trying to shade regions of a plot, and although my logic seems correct, the shaded areas are not being displayed on the plot.

This is my code:

import yfinance as yf
import mplfinance as mpf
import pandas as pd

# Download the stock data
df = yf.download('TSLA', start='2022-01-01', end='2022-03-31')

# Define the date ranges for shading
red_range = ['2022-01-15', '2022-02-15']
blue_range = ['2022-03-01', '2022-03-15']


# Create a function to shade the chart regions
def shade_region(ax, region_dates, color):
    region_dates.sort()

    start_date = region_dates[0]
    end_date = region_dates[1]

    # plot vertical lines
    ax.axvline(pd.to_datetime(start_date), color=color, linestyle='--')
    ax.axvline(pd.to_datetime(end_date), color=color, linestyle='--')

    # create fill
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    ax.fill_between(pd.date_range(start=start_date, end=end_date), ymin, ymax, alpha=0.2, color=color)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)


# Plot the candlestick chart with volume
fig, axlist = mpf.plot(df, type='candle', volume=True, style='charles', 
                        title='TSLA Stock Price', ylabel='Price ($)', ylabel_lower='Shares\nTraded', 
                        figratio=(2,1), figsize=(10,5), tight_layout=True, returnfig=True)

# Get the current axis object
ax = axlist[0]

# Shade the regions on the chart
shade_region(ax, red_range, 'red')
shade_region(ax, blue_range, 'blue')


# Show the plot
mpf.show()

Why are the selected regions not being shaded, and how do I fix this?


Solution

  • The problem is that when show_nontrading=False (the default when not specified), the x-axis values are not dates as you would expect. As a result, the vertical lines and the fill_between that you specify by date end up far off the chart.
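
To get a feel for how far off the chart a date-based coordinate lands, the sketch below compares the two scales using only the standard library. It assumes matplotlib's default date epoch (1970-01-01, the default since matplotlib 3.3) and uses a hand-built list of weekdays as a stand-in for the downloaded index; market holidays are ignored, so the row count is only approximate.

```python
from datetime import date, timedelta

# Stand-in for df.index: weekdays from 2022-01-03 through 2022-03-30.
# (Assumption: market holidays are ignored, so the count is approximate.)
start, end = date(2022, 1, 3), date(2022, 3, 30)
rows = [start + timedelta(days=n) for n in range((end - start).days + 1)]
rows = [d for d in rows if d.weekday() < 5]

# With show_nontrading=False the x-axis spans the row positions 0 .. len(rows)-1.
x_axis_span = (0, len(rows) - 1)

# A date handed to matplotlib is converted to days since its epoch.
date_as_x = (date(2022, 1, 15) - date(1970, 1, 1)).days

print(x_axis_span)  # (0, 62)
print(date_as_x)    # 19007
```

An x value near 19007 on an axis that runs from 0 to roughly 60 is simply outside the visible range, which is why nothing appears to be drawn.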

    The simplest solution is to set show_nontrading=True. Using your code:

    fig, axlist = mpf.plot(df, type='candle', volume=True, style='charles', 
                           title='TSLA Stock Price', ylabel='Price ($)',
                           ylabel_lower='Shares\nTraded', figratio=(2,1), 
                           figsize=(10,5), tight_layout=True, returnfig=True,
                           show_nontrading=True)
    
    # Get the current axis object
    ax = axlist[0]
    
    # Shade the regions on the chart
    shade_region(ax, red_range, 'red')
    shade_region(ax, blue_range, 'blue')
    
    # Show the plot
    mpf.show()
    



    There are two other solutions to the problem that allow you to leave show_nontrading=False, if that is your preference.

    1. The first solution is to use mplfinance's own kwargs (vlines and fill_between) and not use returnfig. Both are described in the mplfinance documentation.

    This is the preferred solution, since it is always a good idea to let mplfinance do all manipulation of Axes objects unless there is something that you cannot accomplish otherwise.

    And here is an example modifying your code:

    red_range = ['2022-01-15', '2022-02-15']
    blue_range = ['2022-03-01', '2022-03-15']
    
    vline_dates  = red_range + blue_range
    vline_colors = ['red','red','blue','blue']
    vline_dict   = dict(vlines=vline_dates,colors=vline_colors,linestyle='--')
    
    ymax = max(df['High'].values)
    ymin = min(df['Low'].values)
    
    # create a dataframe from the datetime index 
    # for using in generating the fill_between `where` values:
    dfdates = df.index.to_frame()
    
    # generate red boolean where values:
    where_values = pd.notnull(dfdates[(dfdates >= red_range[0]) & (dfdates <= red_range[1])].Date.values)
    
    # put together the red fill_between specification:
    fb_red = dict(y1=ymin,y2=ymax,where=where_values,alpha=0.2,color='red')
    
    # generate blue boolean where values:
    where_values = pd.notnull(dfdates[(dfdates >= blue_range[0]) & (dfdates <= blue_range[1])].Date.values)
    
    # put together the blue fill_between specification:
    fb_blue = dict(y1=ymin,y2=ymax,where=where_values,alpha=0.2,color='blue')
    
    # Plot the candlestick chart with volume
    mpf.plot(df, type='candle', volume=True, style='charles', 
             title='TSLA Stock Price', ylabel='Price ($)', ylabel_lower='Shares\nTraded', 
             figratio=(2,1), figsize=(10,5), tight_layout=True, 
             vlines=vline_dict, 
             fill_between=[fb_red,fb_blue])
    



    Notice there is a slight space between the leftmost vertical line and the shaded area. That is because the date you have selected ('2022-01-15') falls on a weekend (and a 3-day weekend at that). If you change the date to '2022-01-14' or '2022-01-18', the line and the shaded area line up.
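
The gap can be checked without plotting. The sketch below uses only the standard library and a hand-written list of trading days around the boundary (an assumption standing in for df.index) to confirm that 2022-01-15 is a Saturday and that the first row satisfying the `where` condition is the 18th.

```python
from datetime import date

# Trading days around the start of the red range (stand-in for df.index).
# 2022-01-17 is absent because the market was closed (the 3-day weekend).
trading_days = [date(2022, 1, 13), date(2022, 1, 14),
                date(2022, 1, 18), date(2022, 1, 19)]

range_start = date(2022, 1, 15)
print(range_start.weekday())  # 5, i.e. Saturday (Monday is 0)

# Same comparison the `where` mask performs: rows on or after the start date.
where = [d >= range_start for d in trading_days]
print(where)  # [False, False, True, True]

# The shading therefore begins at the first True row.
first_shaded = trading_days[where.index(True)]
print(first_shaded)  # 2022-01-18
```

The shaded area starts at the row for the 18th, while a vertical line for a non-trading date is placed between the neighbouring rows (the 14th and the 18th), which would account for the visible gap.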



    2. The last solution requires returnfig=True. This is not the recommended solution but it does work.

    First, it's important to understand the following: when show_nontrading is not specified, it defaults to False, which means that, although you see datetimes displayed on the x-axis, the actual values are the row numbers of your dataframe.

    Therefore, in your code, instead of specifying dates, you specify the row number where each date appears.

    The simplest way to specify the row number is to use function date_to_iloc(df.index.to_series(),date) as defined here:

    def date_to_iloc(dtseries,date):
        '''Convert a `date` to a location, given a date series w/a datetime index.
           If `date` does not exactly match a date in the series then interpolate between two dates.
           If `date` is outside the range of dates in the series, then raise an exception.
        '''
        d1s = dtseries.loc[date:]
        if len(d1s) < 1:
            sdtrange = str(dtseries.iloc[0])+' to '+str(dtseries.iloc[-1])
            raise ValueError('User specified line date "'+str(date)+
                             '" is beyond (greater than) range of plotted data ('+sdtrange+').')
        d1 = d1s.index[0]
        d2s = dtseries.loc[:date]
        if len(d2s) < 1:
            sdtrange = str(dtseries.iloc[0])+' to '+str(dtseries.iloc[-1])
            raise ValueError('User specified line date "'+str(date)+
                             '" is before (less than) range of plotted data ('+sdtrange+').')
        d2 = dtseries.loc[:date].index[-1]
        # If there are duplicate dates in the series, for example in a renko plot
        # then .get_loc(date) will return a slice containing all the dups, so:
        loc1 = dtseries.index.get_loc(d1)
        if isinstance(loc1,slice): loc1 = loc1.start
        loc2 = dtseries.index.get_loc(d2)
        if isinstance(loc2,slice): loc2 = loc2.stop - 1
        return (loc1+loc2)/2.0
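
To make the interpolation concrete, here is a minimal stdlib analogue of the same idea. It is not mplfinance's code: `date_to_position` is a hypothetical helper that works on a sorted list of `datetime.date` values instead of a pandas Series, and the list of days is a hand-written stand-in for df.index.

```python
from bisect import bisect_left, bisect_right
from datetime import date

def date_to_position(days, target):
    """Hypothetical stdlib analogue of date_to_iloc for a sorted list of dates."""
    loc1 = bisect_left(days, target)        # first row on or after target
    loc2 = bisect_right(days, target) - 1   # last row on or before target
    if loc1 >= len(days):
        raise ValueError(f'{target} is after the plotted data')
    if loc2 < 0:
        raise ValueError(f'{target} is before the plotted data')
    # Exact match: loc1 == loc2. Otherwise: midpoint of the two neighbours.
    return (loc1 + loc2) / 2.0

# Stand-in for df.index: a few trading days around the 3-day weekend.
days = [date(2022, 1, 13), date(2022, 1, 14), date(2022, 1, 18), date(2022, 1, 19)]

print(date_to_position(days, date(2022, 1, 14)))  # 1.0 (exact match)
print(date_to_position(days, date(2022, 1, 15)))  # 1.5 (between rows 1 and 2)
```

A date that falls on a non-trading day maps to a fractional position halfway between its neighbours, which is a valid x coordinate on the row-number axis.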
    

    date_to_iloc takes as input the dataframe index converted to a series, so the following changes to your code will allow it to work using this method:

    # Define the date ranges for shading
    red_range = [date_to_iloc(df.index.to_series(),dt) for dt in ['2022-01-15', '2022-02-15']]
    blue_range = [date_to_iloc(df.index.to_series(),dt) for dt in ['2022-03-01', '2022-03-15']]
    
    ...
    
        ax.axvline(start_date, color=color, linestyle='--')
        ax.axvline(end_date, color=color, linestyle='--')
    
    ...
    
        ax.fill_between([start_date,end_date], ymin, ymax, alpha=0.2, color=color)
    

    Everything else stays the same, and the chart now shows both shaded regions.