scikit-learn, naivebayes, grid-search

Scikit-Learn RandomizedSearchCV not working for class_prior in MultinomialNB


I am trying to do randomized parameter optimization on a MultinomialNB (1). However, my parameter has three values rather than one, because it is 'class_prior' and I have 3 classes.

from sklearn.naive_bayes import MultinomialNB
from sklearn.grid_search import RandomizedSearchCV
from scipy.stats import uniform

tuned_parameters = {'class_prior': [uniform.rvs(0, 3), uniform.rvs(0, 3),
                                    uniform.rvs(0, 3)]}
clf = RandomizedSearchCV(MultinomialNB(), tuned_parameters, cv=3,
                         scoring='f1_micro', n_iter=10)

However, calling fit fails with this error:

...
  File "/home/mark/Virtualenvs/python3env2/lib/python3.5/site-packages/sklearn/naive_bayes.py", line 607, in fit
    self._update_class_log_prior(class_prior=class_prior)
  File "/home/mark/Virtualenvs/python3env2/lib/python3.5/site-packages/sklearn/naive_bayes.py", line 455, in _update_class_log_prior
    if len(class_prior) != n_classes:
TypeError: object of type 'numpy.float64' has no len()

I also tried removing the .rvs, which gives:

TypeError: object of type 'rv_frozen' has no len()

Is it impossible to use RandomizedSearchCV on a parameter that has 3 components, namely the 3 class priors?

(1) http://scikit-learn.org/stable/modules/grid_search.html


Solution

  • Yes, it's possible. Do it like this:

    tuned_parameters = {'class_prior': [[uniform.rvs(0, 3), uniform.rvs(0, 3),
                                         uniform.rvs(0, 3)]]}
    

    Notice the extra square brackets around the values. The reason is that each parameter to be tuned by RandomizedSearchCV (or GridSearchCV, for that matter) is given as a list of candidate values, out of which a single element is tried each time. The combination of values that produces the best score is kept; scikit-learn scorers follow a "greater is better" convention, so losses are negated (for example 'neg_log_loss') rather than minimized.
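Besides lists, RandomizedSearchCV also accepts any object that has an rvs method and calls it once per iteration (recent versions pass a random_state argument to it). That allows an alternative not covered in this answer: a small wrapper whose rvs returns a whole prior vector. The sketch below assumes a recent scikit-learn (sklearn.model_selection); DirichletPrior is a hypothetical helper, and the Dirichlet distribution is a swapped-in choice that keeps every sampled vector non-negative and summing to 1.

```python
import numpy as np
from scipy.stats import dirichlet
from sklearn.model_selection import ParameterSampler


class DirichletPrior:
    """Hypothetical helper: draws one class_prior vector per call to rvs()."""

    def __init__(self, alpha):
        self.alpha = np.asarray(alpha, dtype=float)

    def rvs(self, random_state=None):
        # dirichlet.rvs returns an array of shape (1, n_classes); take the single row
        return dirichlet.rvs(self.alpha, random_state=random_state)[0]


# One concentration parameter per class; all ones means uniform over the simplex
tuned_parameters = {'class_prior': DirichletPrior([1.0, 1.0, 1.0])}

# ParameterSampler is what RandomizedSearchCV uses to draw its candidates
samples = list(ParameterSampler(tuned_parameters, n_iter=10, random_state=0))
for params in samples:
    print(params['class_prior'])  # three non-negative values summing to 1
```

The same dict should work as param_distributions for RandomizedSearchCV, with n_iter controlling how many prior vectors are drawn, so no list of pre-drawn candidates is needed.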

    For example, consider this simple parameter grid for tuning an SVC:

    parameters = {'kernel':['linear', 'rbf'], 'C':[1, 10]}
    

    This is expanded into every combination of the values (their Cartesian product), 4 in total:

    Option1:- 'kernel':'linear', 'C':1
    Option2:- 'kernel':'linear', 'C':10
    Option3:- 'kernel':'rbf', 'C':1
    Option4:- 'kernel':'rbf', 'C':10
    

    This means that GridSearchCV fits the estimator once per option (4 options, each fit once per CV fold) and keeps the option with the best cross-validated score. RandomizedSearchCV instead tries only n_iter randomly sampled options.
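The expansion above can be checked with scikit-learn's ParameterGrid, which GridSearchCV iterates over. A minimal sketch, assuming a recent scikit-learn where it lives in sklearn.model_selection (the module that replaced sklearn.grid_search):

```python
from sklearn.model_selection import ParameterGrid

# The SVC grid from the answer: 2 kernels x 2 values of C
parameters = {'kernel': ['linear', 'rbf'], 'C': [1, 10]}

# ParameterGrid yields one dict per combination (the Cartesian product)
options = list(ParameterGrid(parameters))
for option in options:
    print(option)
```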

    In your case, according to the MultinomialNB documentation, class_prior is an array-like of shape (n_classes,) holding the prior probabilities of the classes.
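To make the expected shape concrete, here is a minimal sketch that fits MultinomialNB with an explicit three-element class_prior. The count data and the prior values are made up for illustration. The fitted class_log_prior_ attribute holds the logarithm of the priors passed in, one entry per class in the order of classes_.

```python
import numpy as np
from sklearn.naive_bayes import MultinomialNB

# Hypothetical count data: 6 samples, 4 non-negative features, 3 classes
X = np.array([[2, 0, 1, 0],
              [3, 1, 0, 0],
              [0, 2, 0, 1],
              [0, 3, 1, 0],
              [1, 0, 0, 3],
              [0, 0, 2, 2]])
y = np.array([0, 0, 1, 1, 2, 2])

# One prior per class, in the order of clf.classes_
priors = [0.2, 0.3, 0.5]
clf = MultinomialNB(class_prior=priors).fit(X, y)

print(clf.classes_)                  # class labels the priors refer to
print(np.exp(clf.class_log_prior_))  # the priors, recovered from their logs
```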

    So ideally it should be expanded like this:

    Option1: 'class_prior': [uniform.rvs(0,3), uniform.rvs(0,3), uniform.rvs(0,3)]

    But RandomizedSearchCV, which has no information about the type of class_prior, treats the outer list as the set of candidates and expands it like this:

    Option1: 'class_prior': uniform.rvs(0,3)
    Option2: 'class_prior': uniform.rvs(0,3)
    Option3: 'class_prior': uniform.rvs(0,3)
    

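    Each of these floats is then passed to MultinomialNB as class_prior. Since uniform.rvs() returns a single float rather than a list, it has no len(), hence the error. The same happens without .rvs: each sampled value is then a frozen distribution object (rv_frozen), which has no len() either.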
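The sampling step can be reproduced with ParameterSampler, the helper RandomizedSearchCV uses to draw its candidates. A minimal sketch, assuming a recent scikit-learn (sklearn.model_selection): with single brackets, every sampled class_prior is a lone float.

```python
import numpy as np
from scipy.stats import uniform
from sklearn.model_selection import ParameterSampler

# Single brackets, as in the question: a list of three candidate *floats*
tuned_parameters = {'class_prior': [uniform.rvs(0, 3), uniform.rvs(0, 3),
                                    uniform.rvs(0, 3)]}

# The list holds only 3 candidates, so draw 3 of them
samples = list(ParameterSampler(tuned_parameters, n_iter=3, random_state=0))
for params in samples:
    # Each value is a scalar (ndim 0), so len(params['class_prior']) would fail
    print(params['class_prior'], np.ndim(params['class_prior']))
```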

    To fix this, add a second pair of square brackets so that the whole list becomes a single candidate:

    Option1: 'class_prior': [uniform.rvs(0,3), uniform.rvs(0,3), uniform.rvs(0,3)]
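Running the same check with double brackets shows the difference. Again a minimal sketch assuming a recent scikit-learn: the only candidate is now the whole three-element list, which is what MultinomialNB expects.

```python
from scipy.stats import uniform
from sklearn.model_selection import ParameterSampler

# Double brackets: a list containing one candidate *list* of three priors
tuned_parameters = {'class_prior': [[uniform.rvs(0, 3), uniform.rvs(0, 3),
                                     uniform.rvs(0, 3)]]}

# Only one candidate exists, so draw exactly one
samples = list(ParameterSampler(tuned_parameters, n_iter=1, random_state=0))
print(samples[0]['class_prior'])  # the whole three-element list
```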
    

    But now there is another problem: the expansion yields only a single option, so it will be selected whatever its score, because there is no other choice.

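    Also, you specified n_iter=10, but when every parameter is given as a list, RandomizedSearchCV needs at least n_iter candidates to sample from, and here there is only one. Older scikit-learn versions raise an error in this case; newer versions emit a warning and run only as many iterations as there are candidates.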

    So you need to alter your tuned_parameters like this:

    tuned_parameters = {'class_prior': [[uniform.rvs(0,3), uniform.rvs(0,3), uniform.rvs(0,3)],
                                        [uniform.rvs(0,3), uniform.rvs(0,3), uniform.rvs(0,3)], 
                                        [uniform.rvs(0,3), uniform.rvs(0,3), uniform.rvs(0,3)],
                                        [uniform.rvs(0,3), uniform.rvs(0,3), uniform.rvs(0,3)],
                                        [uniform.rvs(0,3), uniform.rvs(0,3), uniform.rvs(0,3)],
                                        ...
                                        ...
                                        ... ]}
    

    with at least n_iter inner lists, each one being a candidate set of priors.
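Rather than writing the inner lists out by hand, they can be generated with a list comprehension. The sketch below assumes a recent scikit-learn (sklearn.model_selection instead of sklearn.grid_search) and uses randomly generated count data purely for illustration; the data, the seeds and n_iter=10 are assumptions, not part of the original question.

```python
import numpy as np
from scipy.stats import uniform
from sklearn.model_selection import RandomizedSearchCV
from sklearn.naive_bayes import MultinomialNB

n_iter = 10
rng = np.random.RandomState(0)

# Hypothetical data: 90 samples of non-negative counts, 30 per class
X = rng.randint(0, 5, size=(90, 20))
y = np.repeat([0, 1, 2], 30)

# n_iter candidate lists, each holding one prior per class
tuned_parameters = {'class_prior': [
    [uniform.rvs(0, 3, random_state=rng) for _ in range(3)]
    for _ in range(n_iter)
]}

search = RandomizedSearchCV(MultinomialNB(), tuned_parameters, cv=3,
                            scoring='f1_micro', n_iter=n_iter, random_state=0)
search.fit(X, y)
print(search.best_params_['class_prior'], search.best_score_)
```

Because the number of candidates equals n_iter, every pre-drawn list is evaluated exactly once.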