
Use statsmodels model fit on another dataset


Suppose I fit a model on dataset1 using SARIMAX from statsmodels.tsa.statespace.sarimax. Is it possible to then use this fit to make predictions on another dataset, dataset2?

Namely, consider the following:

from statsmodels.tsa.statespace.sarimax import SARIMAX
import pandas as pd
import numpy as np

# generate example data
n=90
idx = pd.PeriodIndex(pd.date_range(start = '2015-01-02',end='2015-04-01',freq='D'))
dat = np.sin(np.linspace(0,12*np.pi,n)) + np.random.randn(n)/10
dataset1 = pd.Series(dat, index = idx)

# fit model
fit = SARIMAX(dataset1, order = (1,0,1)).fit()
# make 30 day forecast on dataset1
fit.forecast(30)

How would I go about using fit to make a prediction on dataset2?

dat = np.sin(np.linspace(0,12*np.pi,n)) + np.random.randn(n)/10
dataset2 = pd.Series(dat, index = idx)

Ideally, it'd be something super simple akin to fit(dataset2).forecast(30), but that clearly isn't the case.

I know I can extract the estimated parameters with fit.params, but short of going through that tedious process, is there a built-in way (or a hack) to reuse the existing fit instance?


Solution

  • You can use the apply method of the results object, which takes the parameters estimated in fit and reuses them on a new dataset:

    from statsmodels.tsa.statespace.sarimax import SARIMAX
    import pandas as pd
    import numpy as np
    
    # generate example data
    n=90
    idx = pd.PeriodIndex(pd.date_range(start = '2015-01-02',end='2015-04-01',freq='D'))
    dat = np.sin(np.linspace(0,12*np.pi,n)) + np.random.randn(n)/10
    dataset1 = pd.Series(dat, index = idx)
    
    # fit model
    fit = SARIMAX(dataset1, order = (1,0,1)).fit()
    # make 30 day forecast on dataset1
    fit.forecast(30)
    
    # ------------------------------------
    
    # get the new dataset
    dat = np.sin(np.linspace(0,12*np.pi,n)) + np.random.randn(n)/10
    dataset2 = pd.Series(dat, index = idx)
    
    # apply the parameters from `fit` to the new dataset
    fit2 = fit.apply(dataset2)
    # make 30 day forecast on dataset2
    fit2.forecast(30)