python, scikit-learn, classification, imbalanced-data

Which scikit-learn classifiers come with a class_weight parameter?


While working on an imbalanced classification project, I was wondering which classifiers come with a class_weight parameter out of the box.

I was inspired by this snippet, which prints every estimator that has a predict_proba method:

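# Older releases exposed this as sklearn.utils.testing.all_estimators
from sklearn.utils import all_estimators

estimators = all_estimators()

# Print every estimator whose class defines predict_proba
for name, class_ in estimators:
    if hasattr(class_, 'predict_proba'):
        print(name)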

compute_class_weight is a function, not a class. So essentially I am looking for a snippet that prints every classifier that relies on the compute_class_weight function, i.e. one whose class_weight can be set to 'balanced'.


Solution

  • You can restrict the search to classifiers (rather than all estimators) and check for a class_weight attribute on the instantiated objects:

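    # Older releases exposed this as sklearn.utils.testing.all_estimators
    from sklearn.utils import all_estimators

    estimators = all_estimators(type_filter='classifier')
    for name, class_ in estimators:
        try:
            model = class_()  # Note the parentheses: class_() creates an instance
        except TypeError:
            continue  # skip classifiers whose constructor requires arguments
        if hasattr(model, 'class_weight'):
            print(name)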
    

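    This prints the classifiers that accept a class_weight parameter and can therefore reweight imbalanced classes (the exact list depends on the installed scikit-learn version):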

    DecisionTreeClassifier
    ExtraTreeClassifier
    ExtraTreesClassifier
    LinearSVC
    LogisticRegression
    LogisticRegressionCV
    NuSVC
    PassiveAggressiveClassifier
    Perceptron
    RandomForestClassifier
    RidgeClassifier
    RidgeClassifierCV
    SGDClassifier
    SVC
    

    Note that class_weight is an attribute of the instantiated models, not of the model classes, because it is only set when the constructor runs. The class LogisticRegression doesn't have class_weight, but a model of type LogisticRegression does. This is the basic object-oriented distinction between a class and an instance. You can see the difference in practice with this code:

    from sklearn.linear_model import LogisticRegression
    
    logreg_class = LogisticRegression
    print(type(logreg_class))
    # >>> <class 'type'>
    
    logreg_model = LogisticRegression()
    print(type(logreg_model))
    # >>> <class 'sklearn.linear_model.logistic.LogisticRegression'>
    # (the module path shown here varies with the scikit-learn version)
    

    During the loop, class_ refers to the model class and class_() is a call to the constructor of that class, which returns an instance.