Tags: python, matplotlib, statsmodels, vector-auto-regression

Modifying a statsmodels graph


I am following the statsmodels documentation here: https://www.statsmodels.org/stable/vector_ar.html

I get to the part at the middle of the page that says: irf.plot(orth=False)

which produces a grid of impulse-response subplots for my data (image not included).

I need to modify elements of the graph. For example, I need to apply tight_layout and also shrink the y-tick labels so that they don't overlap the subplots to their left.

The documentation talks about passing parameters for the "subplot plotting functions" through the subplot_params argument of irf.plot(). But when I try something like:

irf.plot(subplot_params = {'fontsize': 8, 'figsize' : (100, 100), 'tight_layout': True})

only the fontsize parameter works. I also tried passing these parameters to the 'plot_params' argument, but to no avail.

So, my question is: how can I access the other parameters of irf.plot, especially figsize and the y-tick sizes? I also need to force it to draw a grid and to show all values on the x axis (1, 2, 3, 4, ..., 10).

Is there any way I can create a blank plot with fig, ax = plt.subplots() and then draw the irf.plot output on that figure?


Solution

  • It looks like the function returns a matplotlib Figure (matplotlib.figure.Figure), so you can modify it after the call:

    Try doing this:

    fig = irf.plot(orth=False,..)
    fig.set_figheight(100)
    fig.set_figwidth(100)
    fig.tight_layout()  # call after resizing so the spacing fits the final size
    

    If I run it with this example, it works:

    import numpy as np
    import pandas
    import statsmodels.api as sm
    from statsmodels.tsa.api import VAR
    
    mdata = sm.datasets.macrodata.load_pandas().data
    dates = mdata[['year', 'quarter']].astype(int).astype(str)
    quarterly = dates["year"] + "Q" + dates["quarter"]
    from statsmodels.tsa.base.datetools import dates_from_str
    quarterly = dates_from_str(quarterly)
    mdata = mdata[['realgdp','realcons','realinv']]
    mdata.index = pandas.DatetimeIndex(quarterly)
    data = np.log(mdata).diff().dropna()
    model = VAR(data)
    
    results = model.fit(maxlags=15, ic='aic')
    irf = results.irf(10)
    
    fig = irf.plot(orth=False)
    fig.set_figheight(30)
    fig.set_figwidth(30)
    fig.tight_layout()  # call after resizing so the spacing fits the final size
    

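Since the returned object is an ordinary matplotlib Figure, the other requests in the question (smaller tick labels, a grid, a tick at every period) can probably be handled by looping over fig.axes. The following is a minimal sketch, not part of the original answer: the helper name style_irf_figure and its parameters are hypothetical, and a dummy 3x3 figure stands in for the result of irf.plot(orth=False) so that the snippet runs without statsmodels.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def style_irf_figure(fig, periods=10, tick_size=8, size_inches=(12, 12)):
    """Restyle every subplot of a figure such as the one irf.plot() returns.

    Hypothetical helper for illustration; not part of statsmodels.
    """
    fig.set_size_inches(*size_inches)
    for ax in fig.axes:
        ax.tick_params(axis="both", labelsize=tick_size)  # smaller tick labels
        ax.set_xticks(list(range(periods + 1)))           # one tick per period, from 0
        ax.grid(True)                                     # draw grid lines
    fig.tight_layout()  # run last, once the size and tick labels are final
    return fig


# Stand-in for `fig = irf.plot(orth=False)`: a 3x3 grid of dummy responses.
fig, axes = plt.subplots(3, 3)
x = np.arange(11)
for ax in axes.flat:
    ax.plot(x, np.exp(-0.5 * x))

style_irf_figure(fig, periods=10)
```

With real data, the call would be style_irf_figure(irf.plot(orth=False), periods=10), followed by fig.savefig(...) or plt.show(). The x axis of an impulse response starts at period 0; to show only 1 to 10, the tick list could be changed to range(1, periods + 1). The signature of irf.plot in recent statsmodels releases also lists a figsize keyword, so passing figsize directly (rather than inside subplot_params) may be enough for the figure size.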