Tags: python, machine-learning, scikit-learn, linear-regression, doc2vec

Loaded linear regression model doesn't predict as expected


I trained a linear regression model with sklearn to predict 5-star ratings, and its performance is good enough. I used Doc2Vec to create the document vectors and saved that model, then saved the linear regression model to a separate file. Now I'm trying to load both the Doc2Vec model and the linear regression model and use them to predict the rating of a new review.

There is something very strange about the prediction: whatever the input, it always predicts around 2.1 to 3.0.

The thing is, I suspect it is predicting close to the average of the 5-star scale (about 2.5, give or take), but that shouldn't happen: during training I printed the predicted and actual values for the test data, and the predictions ranged normally across 1-5. So my guess is that something is wrong with the loading part of the code. This is my loading code:

```python
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from bs4 import BeautifulSoup
from joblib import dump, load
import pickle
import re

# Load the saved Doc2Vec model
model = Doc2Vec.load('../vectors/750000/doc2vec_model')

def cleanText(text):
    text = BeautifulSoup(text, "lxml").text    # strip HTML markup
    text = re.sub(r'\|\|\|', r' ', text)       # replace '|||' separators with spaces
    text = re.sub(r'http\S+', r'<URL>', text)  # replace URLs with a placeholder
    text = re.sub(r'[^\w\s]', '', text)        # drop punctuation
    text = text.lower()
    text = text.replace('x', '')               # removes every letter 'x'
    return text

# Clean and tokenize a new review, then infer its Doc2Vec vector
review = cleanText("Horrible movie! I don't recommend it to anyone!").split()
vector = model.infer_vector(review)

# Load the saved linear regression model
pkl_filename = "../vectors/750000/linear_regression_model.joblib"
with open(pkl_filename, 'rb') as file:
    linreg = pickle.load(file)

# Predict the star rating for a single review (shape: 1 sample x n features)
review_vector = vector.reshape(1, -1)
predict_star = linreg.predict(review_vector)
print(predict_star)
```
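
To test the suspicion that the loading step is at fault, a save/load round trip can be checked in isolation: a correctly reloaded model must give exactly the same predictions as the in-memory one on the same vectors. This is a minimal sketch under stated assumptions: the 100-dimensional random vectors and 1-5 ratings are synthetic stand-ins for the real Doc2Vec output, and `pickle_save`, `pickle_load` and `round_trip_matches` are hypothetical helpers. With the real objects, one would pass the trained `linreg` and the training vectors instead.

```python
import os
import pickle
import tempfile

import joblib
import numpy as np
from sklearn.linear_model import LinearRegression


def pickle_save(obj, path):
    # Plain-pickle writer, the counterpart of pickle_load below
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def pickle_load(path):
    # Plain-pickle reader, the counterpart of pickle_save above
    with open(path, "rb") as f:
        return pickle.load(f)


def round_trip_matches(model, X, save, load, path):
    """Save `model`, reload it, and check the reloaded copy predicts identically."""
    save(model, path)
    reloaded = load(path)
    return np.array_equal(model.predict(X), reloaded.predict(X))


# Synthetic stand-ins (assumption): 100-dim "document vectors", ratings in 1-5
rng = np.random.default_rng(42)
X = rng.normal(size=(300, 100))
y = np.clip(np.round(3 + X[:, :5].sum(axis=1)), 1, 5)
model = LinearRegression().fit(X, y)

with tempfile.TemporaryDirectory() as tmp:
    # Each pair uses the same library to write and to read the file
    pickle_ok = round_trip_matches(
        model, X, pickle_save, pickle_load, os.path.join(tmp, "model.pkl"))
    joblib_ok = round_trip_matches(
        model, X, joblib.dump, joblib.load, os.path.join(tmp, "model.joblib"))

print(pickle_ok, joblib_ok)  # True True
```

If a matched pair round-trips exactly like this but the real pipeline still collapses to 2.1-3.0, comparing `linreg.predict` on a few training vectors before saving and after loading narrows down which step changes the output.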

Solution

  • Your example code imports both joblib.dump and joblib.load, even though neither is used in this excerpt. Also, the .joblib suffix of your filename suggests that the model may originally have been saved with joblib.dump(), not vanilla pickle.

    But this code loads the file only with plain pickle.load(), which may be the source of the error.

    The joblib.load() docs indicate that it may do things like load numpy arrays from separate files created by joblib's own dump(). (Oddly, the dump() docs are less clear on this, but dump() is documented as returning a list of filenames, which points the same way.)

    Check the directory where the file was saved for extra files that appear to be related, and try joblib.load() instead of plain pickle to see whether that loads a more complete, working version of your linreg object.
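
The two checks suggested above can be sketched as follows, assuming the model was written with joblib.dump() (the saving code isn't shown, so this is only an inference from the .joblib suffix). `related_files` and `load_model` are hypothetical helpers, and the fitted model is a synthetic stand-in for the real linreg: the first lists everything next to the saved file whose name starts with the same filename, the second picks the loader that matches the file suffix so that saving and loading stay symmetric.

```python
import glob
import os
import pickle
import tempfile

import joblib
import numpy as np
from sklearn.linear_model import LinearRegression


def related_files(path):
    """List files next to `path` whose names start with the same filename.

    Older joblib versions could store numpy arrays in companion files
    (e.g. `<filename>_01.npy`) alongside the main file.
    """
    return sorted(glob.glob(glob.escape(path) + "*"))


def load_model(path):
    """Load a model with the library matching its file suffix (hypothetical helper)."""
    if path.endswith(".joblib"):
        return joblib.load(path)  # reads joblib's own format
    with open(path, "rb") as f:
        return pickle.load(f)     # anything else is treated as plain pickle


# Stand-in model (assumption): synthetic 100-dim vectors instead of Doc2Vec output
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 100))
y = X[:, 0] * 2 + 3
linreg = LinearRegression().fit(X, y)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "linear_regression_model.joblib")
    written = joblib.dump(linreg, path)  # list of the filenames joblib wrote
    companions = related_files(path)     # main file plus any companion files
    loaded = load_model(path)

print(written)
print(companions)
print(np.array_equal(loaded.coef_, linreg.coef_))
```

With the real file, calling `related_files("../vectors/750000/linear_regression_model.joblib")` shows whether any companion files exist, and `load_model` on the same path then reads it with joblib rather than plain pickle.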