python-2.7 types scikit-learn random-forest grid-search

Scikit: how to check if an object is a RandomizedSearchCV or a RandomForestClassifier?


I have a few classifiers that were created using grid search, and others that were created directly as random forests.

The plain random forests are of type sklearn.ensemble.forest.RandomForestClassifier, and the ones created with the search are of type sklearn.grid_search.RandomizedSearchCV.

I am trying to check the type of the estimator programmatically (to decide whether I need to go through best_estimator_ to get the feature importances), but can't find a good way to do so.

My first guess was if type(estimator) == 'sklearn.grid_search.RandomizedSearchCV', but that is clearly wrong.


Solution

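  • type() returns the class of its argument (a type object), not a string, so comparing it with the string 'sklearn.grid_search.RandomizedSearchCV' is always False.

    Instead, use isinstance(object, classinfo) to test the type of your estimator.

    It returns True if object is an instance of classinfo or of a subclass of it (classinfo may also be a tuple of classes, in which case any match counts), and False otherwise. Because it respects subclasses, isinstance is generally preferred over comparing type(estimator) with a class using ==.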

    Say you created an estimator of type

    sklearn.ensemble.RandomForestClassifier

    Then

    isinstance(estimator, sklearn.ensemble.RandomForestClassifier)

    returns True, while

    isinstance(estimator, sklearn.grid_search.RandomizedSearchCV)

    returns False.

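    You can then use the result in an if statement, for example to decide whether to read feature_importances_ from estimator.best_estimator_ or from the estimator itself.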

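    Remember to import the submodules that define the classes you test against; a bare import sklearn does not necessarily load them:

    import sklearn.ensemble
    import sklearn.grid_search

    sklearn.grid_search was deprecated in scikit-learn 0.18 and removed in 0.20, so on current versions import sklearn.model_selection and test against sklearn.model_selection.RandomizedSearchCV instead.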
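A minimal sketch of the whole check, assuming scikit-learn 0.18 or later (so the search classes come from sklearn.model_selection) and using the bundled iris dataset as stand-in data. get_feature_importances is a hypothetical helper, not part of scikit-learn. It unwraps a fitted search object through best_estimator_, which is available because refit=True is the default, and it passes a tuple to isinstance so that GridSearchCV is handled the same way as RandomizedSearchCV.

```python
# Assumes scikit-learn >= 0.18, where the search classes live in sklearn.model_selection.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV


def get_feature_importances(estimator):
    """Hypothetical helper: return feature_importances_ from a fitted forest
    or from a fitted search object that wraps one."""
    if isinstance(estimator, (GridSearchCV, RandomizedSearchCV)):
        # The search object has no feature_importances_ of its own; the model
        # refit on the full data (refit=True by default) is in best_estimator_.
        estimator = estimator.best_estimator_
    if not isinstance(estimator, RandomForestClassifier):
        raise TypeError("expected a RandomForestClassifier, got %s"
                        % type(estimator).__name__)
    return estimator.feature_importances_


X, y = load_iris(return_X_y=True)

# A forest fitted directly.
forest = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# A forest tuned with a small randomized search.
search = RandomizedSearchCV(
    RandomForestClassifier(random_state=0),
    param_distributions={"n_estimators": [5, 10, 20], "max_depth": [2, 3, None]},
    n_iter=3,
    cv=3,
    random_state=0,
).fit(X, y)

print(isinstance(forest, RandomForestClassifier))  # True
print(isinstance(search, RandomizedSearchCV))      # True
print(isinstance(search, RandomForestClassifier))  # False
print(get_feature_importances(forest))
print(get_feature_importances(search))
```

Checking the wrapper class with isinstance, rather than just probing for a best_estimator_ attribute, means an unexpected estimator type raises a TypeError instead of silently passing through.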