Here is the code. In the actual code there are two other regressions and their results also end up writing on the same figure as shown in the image below
import pandas as pd
import os
import statsmodels.api as sm
import matplotlib.pyplot as plt
IN_PATH = os.path.join("data", "clean", "imdb_clean.csv")
OUTPUT_DIR = "quantitative analysis"
REVENUE_IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "revenue_imdb_ols_regression.png")
IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "imdb_ols_regression.png")
df = pd.read_csv(IN_PATH)
dummy_cols = df.columns[10:-1]
def revenue_imdb_ols_regression(out_path):
'''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
x_cols = ["IMDBRating", "ReleaseYear"]
for col in dummy_cols:
x_cols.append(col)
x = df[x_cols]
y = df["GrossRevenue"]
model = sm.OLS(y, sm.add_constant(x)).fit()
model_summary = model.summary()
plt.rc("figure", figsize=(12, 7))
plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
plt.axis("off")
plt.tight_layout()
plt.savefig(out_path)
def imdb_ols_regression(out_path):
'''Perform OLS regression of IMBD Rating on genre dummies and create csv'''
x = df[dummy_cols]
y = df["IMDBRating"]
model = sm.OLS(y, sm.add_constant(x)).fit()
model_summary = model.summary()
plt.rc("figure", figsize=(12, 7))
plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
plt.axis("off")
plt.tight_layout()
plt.savefig(out_path)
if __name__ == "__main__":
os.makedirs(OUTPUT_DIR, exist_ok=True)
revenue_imdb_ols_regression(REVENUE_IMDB_OLS_PATH)
imdb_ols_regression(IMDB_OLS_PATH)
def revenue_imdb_ols_regression(out_path):
'''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
x_cols = ["IMDBRating", "ReleaseYear"]
for col in dummy_cols:
x_cols.append(col)
x = df[x_cols]
y = df["GrossRevenue"]
model = sm.OLS(y, sm.add_constant(x)).fit()
model_summary = model.summary()
fig, ax = plt.subplots(figsize=(12, 7))
ax.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
ax.axis("off")
plt.tight_layout()
fig.savefig(out_path)
fig.set_tight_layout(True)
instead of plt.tight_layout()
might work better - try it out