How to properly recover a Word2Vec model created using a SKLearn wrapper?


I am trying to create and store a gensim Word2Vec model in the fit method, wrap it in a scikit-learn pipeline and pickle it, so that I can later call transform on new data.

I created the wrapper, but the self.w2v object seems not to have been fitted: it does not recognize any word, as if it had never seen the training data.

Any ideas about how to address this?

import numpy as np
from sklearn.base import TransformerMixin, BaseEstimator
from gensim.models import Word2Vec

class SentenceVectorizer(TransformerMixin, BaseEstimator):

    def __init__(self, vector_size=50):
        self.vector_size = vector_size
   
    def sent_vectorizer(self, sentence, vectorizer):
        '''
        Looks up each token of one sentence in the fitted W2V model and returns the average of the vectors found.
        '''

        sent_vec =[]
        numw = 0

        for word in sentence:
            try:
                if numw == 0:
                    sent_vec = vectorizer.wv[word]       
                else:
                    sent_vec = np.add(sent_vec, vectorizer.wv[word])
                numw += 1

            except:  # meant for a word missing from the vocabulary, but a bare except also hides every other error
                if numw == 0:
                    sent_vec = np.zeros(self.vector_size)
                else:
                    sent_vec = np.add(sent_vec, np.zeros(self.vector_size))

        if numw > 0:
            return np.asarray(sent_vec) / numw
        else:
            return np.zeros(self.vector_size)

    def fit(self, X, y=None):  # y is ignored; accepting it lets a Pipeline call fit(X, y)
        self.w2v = Word2Vec(X, vector_size=self.vector_size)
        return self

    def transform(self, X):
        X_vec=[]
        for sentence in X:
            X_vec.append(self.sent_vectorizer(sentence, self.w2v))
        return X_vec

This code trains without problems but returns all-zero vectors on inference, because no word is recognized.
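A likely reason the problem showed up as zero vectors rather than as an error is the bare `except:` in `sent_vectorizer`: it catches every exception, not only the `KeyError` raised for an unknown word. The pure-Python sketch below illustrates this with a hypothetical stand-in, `OldStyleModel` (not a gensim class), which, like the older API described in the solution, is indexed as `model[word]` and has no `.wv`. The helper names `average_bare_except` and `average_keyerror_only` are likewise illustrative.

```python
# Hypothetical stand-in (not a gensim class): vectors are read as model[word],
# mimicking the older access pattern described in the solution; there is no `.wv`.
class OldStyleModel:
    def __init__(self, vectors):
        self._vectors = vectors

    def __getitem__(self, word):
        return self._vectors[word]


def average_bare_except(sentence, model, size):
    """Mirrors the question's loop: every exception is treated as 'unknown word'."""
    total, found = [0.0] * size, 0
    for word in sentence:
        try:
            vec = model.wv[word]  # raises AttributeError on OldStyleModel
            total = [a + b for a, b in zip(total, vec)]
            found += 1
        except:  # noqa: E722 (bare on purpose, as in the question)
            pass
    return [x / found for x in total] if found else [0.0] * size


def average_keyerror_only(sentence, model, size):
    """Same loop, but only a missing word (KeyError) is skipped; other errors propagate."""
    total, found = [0.0] * size, 0
    for word in sentence:
        try:
            vec = model.wv[word]
        except KeyError:
            continue
        total = [a + b for a, b in zip(total, vec)]
        found += 1
    return [x / found for x in total] if found else [0.0] * size


model = OldStyleModel({"cat": [1.0, 2.0], "dog": [3.0, 4.0]})
print(average_bare_except(["cat", "dog"], model, 2))  # [0.0, 0.0]: the AttributeError was swallowed
try:
    average_keyerror_only(["cat", "dog"], model, 2)
except AttributeError as exc:
    print("error surfaced:", exc)
```

With `except KeyError`, an API mismatch fails loudly on the first lookup instead of silently producing zero vectors, while genuinely unknown words are still skipped.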

Most likely problem: the fit method is not storing self.w2v properly, although self.w2v does seem to exist when transform is called.
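One way to test the hypothesis that fit does not store `self.w2v` is to inspect the object after a pickle round-trip. The sketch below uses a hypothetical `TinyVectorizer` in place of `SentenceVectorizer` so it runs without gensim or scikit-learn; it only shows that an attribute assigned in `fit` is pickled together with its contents. With the real class, the analogous check would be that the restored object still has `w2v` and that its vocabulary is non-empty.

```python
import pickle


class TinyVectorizer:
    """Hypothetical stand-in for SentenceVectorizer: fit() stores a lookup table."""

    def fit(self, X, y=None):
        # Fitted state lives in an attribute, as SentenceVectorizer does with self.w2v
        self.vocab_ = {word: i for i, word in enumerate(sorted({w for s in X for w in s}))}
        return self

    def transform(self, X):
        # -1 marks a token that was not seen during fit
        return [[self.vocab_.get(w, -1) for w in s] for s in X]


fitted = TinyVectorizer().fit([["the", "cat"], ["the", "dog"]])
restored = pickle.loads(pickle.dumps(fitted))

# The attribute set in fit() survives pickling, together with its contents
print(hasattr(restored, "vocab_"), len(restored.vocab_))
print(restored.transform([["dog", "bird"]]))
```

If the same check passes for the real transformer, the fitted model is being stored and restored correctly, and the cause has to be looked for in how words are looked up.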


Solution

  • It turns out I had an outdated gensim version, which required vectorizer[word] instead of vectorizer.wv[word]. I'll leave the question here in case it is useful to someone.
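Besides upgrading gensim, a wrapper that has to run against both access styles could resolve the lookup in one place. This is a minimal sketch, assuming only the two patterns named in the solution (`model.wv[word]` and `model[word]`); `word_vector`, `NewStyleModel` and `OldStyleModel` are hypothetical names, and the stand-in classes are not gensim classes.

```python
def word_vector(model, word):
    """Return the vector for `word`, using `model.wv[word]` when the model has a
    `.wv` attribute and falling back to `model[word]` otherwise."""
    keyed_vectors = getattr(model, "wv", model)  # fall back to the model itself
    return keyed_vectors[word]  # the stand-ins below raise KeyError for unknown words


# Hypothetical stand-ins for the two access styles (not real gensim classes)
class NewStyleModel:
    def __init__(self, vectors):
        self.wv = vectors  # a plain dict plays the role of the keyed vectors


class OldStyleModel:
    def __init__(self, vectors):
        self._vectors = vectors

    def __getitem__(self, word):
        return self._vectors[word]


vectors = {"cat": [1.0, 2.0]}
print(word_vector(NewStyleModel(vectors), "cat"))
print(word_vector(OldStyleModel(vectors), "cat"))
```

Preferring `.wv` whenever it exists means the newer access path is used as soon as it is available; combined with `except KeyError` in the averaging loop, any other mismatch still raises instead of turning into zeros.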