I am trying to plot a confidence interval band around the predicted values of a SARIMAX model.
I fit the SARIMAX model like this:
import statsmodels.api as sm
model = sm.tsa.statespace.SARIMAX(data_df['Net Sales'], order=(1, 1, 1), seasonal_order=(1, 1, 1, 12))
results = model.fit()
To plot the predicted values, I am using the following code:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(15, 5))
ax.ticklabel_format(useOffset=False, style='plain')
data_df['Net_Sales forecast'] = results.predict(start=48, end=60, dynamic=True)
data_df[['Net Sales', 'Net_Sales forecast']].plot(ax=ax, color=['blue', 'orange'], marker='o', legend=True)
I want to plot a 95% confidence interval band around the forecast data. I have tried various ways, but to no avail.
I understand that I can get confidence intervals from the fitted SARIMAX results using the following:
ci = results.conf_int(alpha=0.05)
                     0             1
ar.L1    -3.633910e-01  1.108174e+00
ma.L1    -1.253388e+00  2.229091e-01
ar.S.L12 -3.360182e+00  4.001006e+00
ma.S.L12 -4.078321e+00  3.517885e+00
sigma2    3.080743e+13  3.080743e+13
How do I incorporate this into the plot to show the confidence interval band?
The confidence intervals you show are actually for the model parameters, not for the predictions. Here is an example of how you can compute and plot confidence intervals around the predictions, borrowing a dataset used in the statsmodels documentation:
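Note: you'll need to be cautious when interpreting these confidence intervals (for example, they treat the estimated parameters as known, so they do not reflect parameter-estimation uncertainty). Here is a relevant page discussing what is actually implemented in statsmodels.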
import matplotlib.pyplot as plt
import pandas as pd
import statsmodels.api as sm
import requests
from io import BytesIO
# Get data
wpi1 = requests.get('https://www.stata-press.com/data/r12/wpi1.dta').content
data = pd.read_stata(BytesIO(wpi1))
data.index = data.t
# Set the frequency (quarterly data, quarters starting in Jan/Apr/Jul/Oct)
data.index.freq = 'QS-OCT'
# Fit the model
model = sm.tsa.statespace.SARIMAX(data['wpi'], trend='c', order=(1,1,1))
results = model.fit(disp=False)
# Get predictions
# (can also utilize results.get_forecast(steps=n).summary_frame(alpha=0.05))
preds_df = (results
            .get_prediction(start='1991-01-01', end='1999-10-01')
            .summary_frame(alpha=0.05))
print(preds_df.head())
# wpi               mean   mean_se  mean_ci_lower  mean_ci_upper
# 1991-01-01  118.358860  0.725041     116.937806     119.779914
# 1991-04-01  120.340500  1.284361     117.823198     122.857802
# 1991-07-01  122.167206  1.865597     118.510703     125.823709
# 1991-10-01  123.858465  2.463735     119.029634     128.687296
# 1992-01-01  125.431312  3.070871     119.412517     131.450108
# Plot the training data, predicted means and confidence intervals
fig, ax = plt.subplots(figsize=(15, 5))
data['wpi'].plot(ax=ax, label='Training Data')
ax.set(
    title='True and Predicted Values, with Confidence Intervals',
    ylabel='Actual / Predicted Values'
)
preds_df['mean'].plot(ax=ax, style='r', label='Predicted Mean')
# Shade the area between the lower and upper 95% bounds
ax.fill_between(
    preds_df.index, preds_df['mean_ci_lower'], preds_df['mean_ci_upper'],
    color='r', alpha=0.1
)
legend = ax.legend(loc='upper left')
plt.show()