Tags: python, machine-learning, text-classification, multiclass-classification

Saving a trained multi-input classification algorithm in Python


I developed a script that predicts probable tags for a piece of text, based on feedback that was previously tagged by hand. Several online articles helped me, in particular: https://towardsdatascience.com/multi-label-text-classification-with-scikit-learn-30714b7819c5.

Because I want the probability for each tag, here's the code I used:

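from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
import pandas as pd

NB_pipeline = Pipeline([
    ('clf', OneVsRestClassifier(MultinomialNB(alpha=0.3, fit_prior=True, class_prior=None))),
    ])

predictions_en = {}
for category in categories_en:
    # Refit the pipeline on the binary labels of one tag
    NB_pipeline.fit(all_x_en, en_topics[category])
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    # Probability of the positive class for the last text in pred_x_en
    predictions_en[category] = proba_en[-1][-1]

# One row per tag: column 0 = tag, column 1 = probability
preds_en = pd.DataFrame(list(predictions_en.items()))
preds_en = preds_en.sort_values(by=[1], ascending=False)
preds_en = preds_en.reset_index(drop=True)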
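For readers who want to run the snippet above end to end, here is a self-contained sketch. The texts, the tag names, and the CountVectorizer step (which turns raw text into the count features MultinomialNB expects) are hypothetical stand-ins for all_x_en, en_topics, and pred_x_en, which the question does not show.

```python
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Hypothetical toy data standing in for the real tagged feedback.
texts = [
    "the app crashes on startup",
    "crash when I open settings",
    "love the new design",
    "great look and feel",
    "too expensive for what it does",
    "price is too high",
]
en_topics = pd.DataFrame({
    "bug":    [1, 1, 0, 0, 0, 0],
    "design": [0, 0, 1, 1, 0, 0],
    "price":  [0, 0, 0, 0, 1, 1],
})
categories_en = list(en_topics.columns)

# Assumption: features are word counts from CountVectorizer.
vectorizer = CountVectorizer()
all_x_en = vectorizer.fit_transform(texts)
pred_x_en = vectorizer.transform(["the app crashes all the time"])

NB_pipeline = Pipeline([
    ("clf", OneVsRestClassifier(MultinomialNB(alpha=0.3, fit_prior=True, class_prior=None))),
])

predictions_en = {}
for category in categories_en:
    # Refit the same pipeline on the binary target of this one tag.
    NB_pipeline.fit(all_x_en, en_topics[category])
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    # Last row = last text to predict, last column = P(tag applies).
    predictions_en[category] = proba_en[-1][-1]

preds_en = pd.DataFrame(list(predictions_en.items()), columns=["tag", "probability"])
preds_en = preds_en.sort_values(by="probability", ascending=False).reset_index(drop=True)
print(preds_en)
```

Naming the columns ("tag", "probability") is only a readability choice; sorting by column 1 as in the original gives the same order.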

It works very well for my purposes: it returns a prediction for each possible tag. My issue is that it retrains the algorithm every time I make a prediction. What I'd like is to train the algorithm in one script, save the trained model, and load it in another script where the predictions are made.

I'd like to be able to do this in script 1:

for category in categories_en:
    NB_pipeline.fit(all_x_en, en_topics[category])

And this in the other script:

for category in categories_en:
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    predictions_en[category] = proba_en[-1][-1]

But I can't get it to work: when I split the code this way, it just gives me the same prediction.
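A likely explanation, worth checking against the real code: NB_pipeline is a single object, and every call to fit replaces what it learned before. After the training loop in script 1 it probably holds only the model for the last entry in categories_en, so the loop in script 2 asks that one model about every tag. The sketch below reproduces the effect on a hypothetical toy corpus; the texts, tag names, and CountVectorizer step are assumptions, not part of the question.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Hypothetical toy data: four texts and two tags.
texts = [
    "the app crashes on startup",
    "crash when I open settings",
    "love the new design",
    "great look and feel",
]
en_topics = {"bug": [1, 1, 0, 0], "design": [0, 0, 1, 1]}
categories_en = list(en_topics)

vectorizer = CountVectorizer()
all_x_en = vectorizer.fit_transform(texts)
pred_x_en = vectorizer.transform(["the app crashes"])

NB_pipeline = Pipeline([
    ("clf", OneVsRestClassifier(MultinomialNB(alpha=0.3, fit_prior=True, class_prior=None))),
])

# "Script 1": each fit overwrites what the previous iteration learned.
for category in categories_en:
    NB_pipeline.fit(all_x_en, en_topics[category])

# "Script 2": NB_pipeline now only knows the last tag ("design"),
# so every tag receives the same probability.
predictions_en = {}
for category in categories_en:
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    predictions_en[category] = proba_en[-1][-1]

print(predictions_en)
```

Keeping one fitted object per category, rather than refitting a shared one, avoids this.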


Solution

  • You can use pickle to serialize most Python objects, including your fitted pipeline. The simplest and fastest way to save your model is to serialize it to a file, say model.pickle, in the first script right after training. In the second script, all you have to do is check that the file exists and deserialize it, again with pickle.

    This function serializes a Python object to a file:

    import pickle
    
    def serialize(obj, file):
        with open(file, 'wb') as f:
            pickle.dump(obj, f)
    
    

    This function deserializes a Python object from a file:

    import pickle
    
    def deserialize(file):
        with open(file, 'rb') as f:
            return pickle.load(f)
    

    Once you're done training, just call (assuming NB_pipeline is your model object):

    serialize(NB_pipeline, 'model.pickle')
    

    When you need to load and use it, call:

    NB_pipeline = deserialize('model.pickle')
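Applied to the question's loop, pickling NB_pipeline after the loop would still save only the last category's fit. One way around this, sketched below, is to make a fresh copy of the pipeline per category with sklearn.base.clone, keep the fitted copies in a dict keyed by category, and pickle the whole dict with the same serialize/deserialize helpers. The toy data, the CountVectorizer step inside the pipeline, and the file name models_en.pickle are assumptions for illustration. Only unpickle files you created yourself, since loading a pickle can execute arbitrary code, and load the models with the same scikit-learn version used to train them.

```python
import pickle

import pandas as pd
from sklearn.base import clone
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline


def serialize(obj, file):
    with open(file, "wb") as f:
        pickle.dump(obj, f)


def deserialize(file):
    with open(file, "rb") as f:
        return pickle.load(f)


# ---- Script 1: train one pipeline per category and save them together ----
# Hypothetical toy data standing in for the real feedback and its tags.
texts = [
    "the app crashes on startup",
    "crash when I open settings",
    "love the new design",
    "great look and feel",
    "too expensive for what it does",
    "price is too high",
]
en_topics = pd.DataFrame({
    "bug":    [1, 1, 0, 0, 0, 0],
    "design": [0, 0, 1, 1, 0, 0],
    "price":  [0, 0, 0, 0, 1, 1],
})
categories_en = list(en_topics.columns)

# Assumption: the vectorizer lives inside the pipeline, so the saved models
# accept raw text and script 2 needs no separately saved vectorizer.
NB_pipeline = Pipeline([
    ("vect", CountVectorizer()),
    ("clf", OneVsRestClassifier(MultinomialNB(alpha=0.3, fit_prior=True, class_prior=None))),
])

models_en = {}
for category in categories_en:
    model = clone(NB_pipeline)  # fresh, unfitted copy for this tag only
    model.fit(texts, en_topics[category])
    models_en[category] = model

serialize(models_en, "models_en.pickle")  # hypothetical file name

# ---- Script 2: load the saved models and predict without retraining ----
models_en = deserialize("models_en.pickle")
pred_texts = ["the app crashes all the time"]

predictions_en = {}
for category, model in models_en.items():
    proba_en = model.predict_proba(pred_texts)
    predictions_en[category] = proba_en[-1][-1]  # P(tag) for the last text

preds_en = pd.DataFrame(list(predictions_en.items()), columns=["tag", "probability"])
preds_en = preds_en.sort_values(by="probability", ascending=False).reset_index(drop=True)
print(preds_en)
```

If the features are vectorized outside the pipeline, as all_x_en seems to be in the question, the fitted vectorizer has to be saved as well (for example, as another entry in the same dict). joblib.dump and joblib.load can replace the pickle helpers here; the scikit-learn documentation mentions them as an option for objects that hold large NumPy arrays.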