Plotting confidence interval in SARIMAX prediction data


I am trying to plot a confidence interval band around the predicted values of a SARIMAX model.

The SARIMAX model is fitted as follows:

import statsmodels.api as sm

model = sm.tsa.statespace.SARIMAX(data_df['Net Sales'], order=(1, 1, 1), seasonal_order=(1, 1, 1, 12))
results = model.fit()
print(results.summary())

To plot the predicted values I am using the following code:

fig, ax = plt.subplots(figsize=(15,5))
ax.ticklabel_format(useOffset=False, style='plain')
data_df['Net_Sales forecast'] = results.predict(start=48, end=60, dynamic=True)
data_df[['Net Sales', 'Net_Sales forecast']].plot(ax=ax, color=['blue', 'orange'], marker='o', legend=True)

I want to plot a 95% confidence interval band around the forecast. I have tried various approaches, but to no avail.

I understand that I can access confidence intervals from the fitted SARIMAX results using the following:

ci = results.conf_int(alpha=0.05)
ci

Returns:

                     0             1
ar.L1    -3.633910e-01  1.108174e+00
ma.L1    -1.253388e+00  2.229091e-01
ar.S.L12 -3.360182e+00  4.001006e+00
ma.S.L12 -4.078321e+00  3.517885e+00
sigma2    3.080743e+13  3.080743e+13

How do I incorporate this into the plot to show the confidence interval band?


Solution

  • The confidence intervals you show are actually for model parameters, not for predictions. Here is an example of how you can compute and plot confidence intervals around the predictions, borrowing a dataset used in the statsmodels docs.

    Note: You'll need to be cautious about interpreting these confidence intervals. Here is a relevant page discussing what is actually implemented in statsmodels.

    import matplotlib.pyplot as plt
    import pandas as pd
    import statsmodels.api as sm
    import requests
    from io import BytesIO
    
    # Get data
    wpi1 = requests.get('https://www.stata-press.com/data/r12/wpi1.dta').content
    data = pd.read_stata(BytesIO(wpi1))
    data.index = data.t
    # Set the frequency
    data.index.freq='QS-OCT'
    
    # Fit the model
    model = sm.tsa.statespace.SARIMAX(data['wpi'], trend='c', order=(1,1,1))
    results = model.fit(disp=False)
    
    # Get predictions
    # (can also utilize results.get_forecast(steps=n).summary_frame(alpha=0.05))
    preds_df = (results
                .get_prediction(start='1991-01-01', end='1999-10-01')
                .summary_frame(alpha=0.05)
    )
    print(preds_df.head())
    # wpi               mean   mean_se  mean_ci_lower  mean_ci_upper
    # 1991-01-01  118.358860  0.725041     116.937806     119.779914
    # 1991-04-01  120.340500  1.284361     117.823198     122.857802
    # 1991-07-01  122.167206  1.865597     118.510703     125.823709
    # 1991-10-01  123.858465  2.463735     119.029634     128.687296
    # 1992-01-01  125.431312  3.070871     119.412517     131.450108
    
    # Plot the training data, predicted means and confidence intervals
    fig, ax = plt.subplots(figsize=(15,5))
    # Draw on the axes created above instead of relying on the implicit current axes
    data['wpi'].plot(ax=ax, label='Training Data')
    ax.set(
        title='True and Predicted Values, with Confidence Intervals',
        xlabel='Date',
        ylabel='Actual / Predicted Values'
    )
    preds_df['mean'].plot(ax=ax, style='r', label='Predicted Mean')
    ax.fill_between(
        preds_df.index, preds_df['mean_ci_lower'], preds_df['mean_ci_upper'],
        color='r', alpha=0.1
    )
    legend = ax.legend(loc='upper left')
    plt.show()
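
To make the meaning of the `mean_ci_lower` and `mean_ci_upper` columns more concrete, here is a small sketch that recomputes the first row of the printed output by hand. It assumes the interval is a two-sided normal approximation around the predicted mean (mean ± z × mean_se), which is consistent with the numbers printed above; the variable names are only illustrative.

```python
from statistics import NormalDist

alpha = 0.05
# Two-sided normal quantile for a (1 - alpha) interval, roughly 1.96 for 95%
z = NormalDist().inv_cdf(1 - alpha / 2)

# First row of the summary_frame output printed above (1991-01-01)
mean, mean_se = 118.358860, 0.725041

# Assumed form of the interval: mean +/- z * standard error
lower = mean - z * mean_se
upper = mean + z * mean_se
print(round(lower, 4), round(upper, 4))
```

The result agrees with the `mean_ci_lower` and `mean_ci_upper` values for 1991-01-01 up to rounding, so a different `alpha` simply changes the quantile `z` and therefore the width of the band.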
    
