Search code examples
pythonmatplotlibstatsmodels

matplotlib keeps writing over the same figure. I need regression results as separate png's. How should I modify the code?


Here is the code. In the actual code there are two other regressions and their results also end up writing on the same figure as shown in the image below

import pandas as pd
import os
import statsmodels.api as sm
import matplotlib.pyplot as plt

IN_PATH = os.path.join("data", "clean", "imdb_clean.csv")
OUTPUT_DIR = "quantitative analysis"
REVENUE_IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "revenue_imdb_ols_regression.png")
IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "imdb_ols_regression.png")

df = pd.read_csv(IN_PATH)
dummy_cols = df.columns[10:-1]


def revenue_imdb_ols_regression(out_path):
    '''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
    
    x_cols = ["IMDBRating", "ReleaseYear"]
    for col in dummy_cols:
        x_cols.append(col)

    x = df[x_cols]
    y = df["GrossRevenue"]
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    model_summary = model.summary()
    
    
    plt.rc("figure", figsize=(12, 7))
    plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(out_path)

def imdb_ols_regression(out_path):
    '''Perform OLS regression of IMBD Rating on genre dummies and create csv'''
    
    x = df[dummy_cols]
    y = df["IMDBRating"]

    model = sm.OLS(y, sm.add_constant(x)).fit()
    model_summary = model.summary()
    
    
    plt.rc("figure", figsize=(12, 7))
    plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(out_path)

if __name__ == "__main__":
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    revenue_imdb_ols_regression(REVENUE_IMDB_OLS_PATH)
    imdb_ols_regression(IMDB_OLS_PATH)

enter image description here


Solution

  • def revenue_imdb_ols_regression(out_path):
        '''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
        
        x_cols = ["IMDBRating", "ReleaseYear"]
        for col in dummy_cols:
            x_cols.append(col)
    
        x = df[x_cols]
        y = df["GrossRevenue"]
        
        model = sm.OLS(y, sm.add_constant(x)).fit()
        model_summary = model.summary()
        
        
        fig, ax = plt.subplots(figsize=(12, 7))
        
        ax.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
        ax.axis("off")
        plt.tight_layout()
        fig.savefig(out_path)
    

    fig.set_tight_layout(True) instead of plt.tight_layout()might work better - try it out