Tags: machine-learning, scikit-learn, naive-bayes

Huge number of classes with Multinomial Naive Bayes (scikit-learn)


Whenever I have a larger number of classes (1000 and more), MultinomialNB gets very slow and takes gigabytes of RAM. The same is true for all of the scikit-learn classification algorithms that support .partial_fit() (SGDClassifier, Perceptron). Convolutional neural networks handle 10000 classes without any problem, but when I want to train MultinomialNB on the same data, my 12 GB of RAM are not enough and training is very slow. From my understanding of Naive Bayes, even with many classes it should be much faster. Might this be a problem of the scikit-learn implementation (maybe of the .partial_fit() function)? How can I train MultinomialNB/SGDClassifier/Perceptron on 10000+ classes (batchwise)?


Solution

  • Short answer without much information:

    • MultinomialNB fits an independent model for each class, so with C=10000+ classes it effectively fits C=10000+ models. The model parameters alone form a matrix of shape [n_classes x n_features], which is a lot of memory if n_features is large (see the back-of-envelope estimate below).

    • SGDClassifier in scikit-learn uses the OVA (one-versus-all) strategy to train a multiclass model (as SGDClassifier is not inherently multiclass), so another C=10000+ models need to be trained.

    • And for Perceptron, from the scikit-learn documentation:

    Perceptron and SGDClassifier share the same underlying implementation. In fact, Perceptron() is equivalent to SGDClassifier(loss="perceptron", eta0=1, learning_rate="constant", penalty=None).

    So, all three classifiers you mention do not work well with a large number of classes, because an independent model needs to be trained for each class. I would recommend trying something that inherently supports multiclass classification, such as RandomForestClassifier (see the sketch below).
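
To see why the memory blows up, here is a rough back-of-envelope estimate. All three estimators end up holding at least one dense [n_classes x n_features] matrix of float64 parameters (e.g. MultinomialNB.feature_log_prob_, or the OVA coef_ of SGDClassifier/Perceptron); the feature count below is a made-up example, not a value from the question:

```python
# Rough memory estimate for one dense [n_classes x n_features] parameter matrix.
n_classes = 10_000
n_features = 100_000      # hypothetical, e.g. a large bag-of-words vocabulary
bytes_per_value = 8       # float64

gib = n_classes * n_features * bytes_per_value / 1024**3
print(f"one parameter matrix: {gib:.1f} GiB")   # roughly 7.5 GiB
```

MultinomialNB additionally keeps a feature_count_ matrix of the same shape, so its footprint roughly doubles that estimate.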
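
And a minimal sketch of the suggested alternative on synthetic data (the dataset sizes are arbitrary placeholders, only meant to illustrate that a single RandomForestClassifier handles many classes natively rather than one model per class):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic multiclass data; sizes are arbitrary and only illustrative.
X, y = make_classification(
    n_samples=50_000, n_features=50, n_informative=20,
    n_classes=1_000, n_clusters_per_class=1, random_state=0,
)

# One forest covers all classes at once; no per-class model is trained.
clf = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=0)
clf.fit(X, y)
print(clf.predict(X[:5]))
```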