
Saving a cross-validation trained model in Scikit


I have trained a model in scikit-learn using cross-validation and a Naive Bayes classifier. How can I persist this model so I can later run it against new instances?

Here is what I have. I can get the CV scores, but I don't know how to access the trained model:

from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
scores = cross_val_score(gnb, data_numpy[0], data_numpy[1], cv=10)
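A reproducible version of the snippet above, as a minimal sketch: it assumes the bundled iris dataset (sklearn.datasets.load_iris) stands in for data_numpy[0] and data_numpy[1], and imports cross_val_score from sklearn.model_selection, where current scikit-learn releases keep it. It shows that cross_val_score hands back only an array of scores, one per fold:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB

# Assumption: iris features/labels stand in for data_numpy[0] / data_numpy[1]
X, y = load_iris(return_X_y=True)

gnb = GaussianNB()
# One accuracy score per fold; no fitted model is returned
scores = cross_val_score(gnb, X, y, cv=10)
print(scores.shape)  # (10,)
print(scores.mean())
```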

Solution

  • cross_val_score does not change your estimator, and it does not return a fitted estimator. It fits a copy (clone) of the estimator on each training fold and returns only the score from each fold.

    To get a fitted estimator, call fit on it explicitly with your dataset. To save (serialize) it, you can use pickle:

    # To fit your estimator
    gnb.fit(data_numpy[0], data_numpy[1])
    # To serialize
    import pickle
    with open('our_estimator.pkl', 'wb') as fid:
        pickle.dump(gnb, fid)
    # To deserialize estimator later
    with open('our_estimator.pkl', 'rb') as fid:
        gnb = pickle.load(fid)
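If you want the models that were actually trained during cross-validation, rather than a fresh one fitted on all the data, sklearn.model_selection.cross_validate accepts return_estimator=True and returns the fitted estimator from each fold. A minimal sketch, assuming the iris dataset in place of data_numpy and a hypothetical file name folds.pkl:

```python
import pickle

from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.naive_bayes import GaussianNB

# Assumption: iris stands in for data_numpy[0] / data_numpy[1]
X, y = load_iris(return_X_y=True)

gnb = GaussianNB()
# return_estimator=True keeps the estimator fitted on each training fold
cv_results = cross_validate(gnb, X, y, cv=10, return_estimator=True)
fold_models = cv_results["estimator"]  # one fitted GaussianNB per fold

# The estimator passed in is still unfitted: cross_validate works on clones
print(hasattr(gnb, "classes_"))  # False

# Persist all fold models with pickle (hypothetical file name), then reload
with open("folds.pkl", "wb") as fid:
    pickle.dump(fold_models, fid)
with open("folds.pkl", "rb") as fid:
    restored = pickle.load(fid)

# Each restored model can predict on new instances
print(restored[0].predict(X[:5]))
```

Which model to keep is a judgment call: the per-fold models are the ones the CV scores describe, while refitting on the full dataset gives one final model trained on all available data.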