Tags: python, scikit-learn, classification, adaboost

LDA as base learner for AdaBoost in Python


I'm working on multi-class classification using AdaBoost, with a discriminant (linear or quadratic) as the base learner. I couldn't find any functionality in scikit-learn or any other library to implement this. How do I go about it?


Solution

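  • Although scikit-learn's AdaBoostClassifier lets you choose the base estimator (see the documentation), it requires that estimator's fit method to accept a sample_weight argument. Take a look at the check in the source (quoted from an older release; since scikit-learn 1.2 the parameter and attribute are named estimator and estimator_ rather than base_estimator and base_estimator_):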

    if not has_fit_parameter(self.base_estimator_, "sample_weight"):
        raise ValueError("%s doesn't support sample_weight."
                         % self.base_estimator_.__class__.__name__)
    

    Unfortunately, neither LinearDiscriminantAnalysis nor QuadraticDiscriminantAnalysis accepts sample_weight in fit. Here's a toy example:

    from sklearn.ensemble import AdaBoostClassifier
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
    from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target)
    
    # scikit-learn >= 1.2; on older versions pass base_estimator=LDA() instead
    clf = AdaBoostClassifier(estimator=LDA())
    clf.fit(X_train, y_train)
    

    You'll see a traceback like the following:

    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "//anaconda/lib/python2.7/site-packages/sklearn/ensemble/weight_boosting.py", line 411, in fit
        return super(AdaBoostClassifier, self).fit(X, y, sample_weight)
      File "//anaconda/lib/python2.7/site-packages/sklearn/ensemble/weight_boosting.py", line 128, in fit
        self._validate_estimator()
      File "//anaconda/lib/python2.7/site-packages/sklearn/ensemble/weight_boosting.py", line 429, in _validate_estimator
        % self.base_estimator_.__class__.__name__)
    ValueError: LinearDiscriminantAnalysis doesn't support sample_weight.
    

    There is no way around this within AdaBoostClassifier; the documentation states it as a hard requirement:

    "...Support for sample weighting is required, as well as proper classes_ and n_classes_ attributes."

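    However, if you simply want an ensemble of discriminant classifiers, you can use bagging instead of boosting. BaggingClassifier does not need sample_weight: when the base estimator lacks it, each member is fit on a bootstrap sample of the training rows.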

    from sklearn.ensemble import BaggingClassifier
    # scikit-learn >= 1.2; on older versions pass base_estimator=LDA() instead
    clf = BaggingClassifier(estimator=LDA())
    clf.fit(X_train, y_train)
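Putting the pieces together, here is a minimal sketch (not from the original answer) that checks which candidate base learners accept sample_weight, confirms whether AdaBoostClassifier rejects LDA, and then fits a bagged LDA ensemble instead. It assumes scikit-learn is installed. ensemble_kwargs is a hypothetical helper that passes the base learner as estimator= (scikit-learn 1.2 and later) or base_estimator= (earlier releases); the split seed and n_estimators=10 are arbitrary choices.

```python
# A minimal sketch, assuming scikit-learn is installed; not the original author's code.
import inspect

from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import has_fit_parameter


def ensemble_kwargs(ensemble_cls, base):
    """Hypothetical helper: pass `base` under whichever keyword this
    scikit-learn version accepts (`estimator` from 1.2, `base_estimator` before)."""
    params = inspect.signature(ensemble_cls.__init__).parameters
    key = "estimator" if "estimator" in params else "base_estimator"
    return {key: base}


# Which candidate base learners accept sample_weight in fit()?
supports = {
    name: has_fit_parameter(est, "sample_weight")
    for name, est in [("LDA", LDA()), ("QDA", QDA()), ("tree", DecisionTreeClassifier())]
}
print(supports)

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0, stratify=iris.target
)

# AdaBoost refuses a base learner whose fit() lacks sample_weight.
adaboost_rejected = False
try:
    AdaBoostClassifier(**ensemble_kwargs(AdaBoostClassifier, LDA())).fit(X_train, y_train)
except ValueError as exc:
    adaboost_rejected = True
    print("AdaBoost:", exc)

# Bagging needs no sample_weight: each LDA is fit on a bootstrap sample of the rows.
bag = BaggingClassifier(
    **ensemble_kwargs(BaggingClassifier, LDA()), n_estimators=10, random_state=0
)
bag.fit(X_train, y_train)
bag_acc = bag.score(X_test, y_test)
single_acc = LDA().fit(X_train, y_train).score(X_test, y_test)
print(f"single LDA: {single_acc:.3f}, bagged LDA: {bag_acc:.3f}")
```

has_fit_parameter simply inspects the signature of the estimator's fit method, which is the same test AdaBoostClassifier applies, so it is a quick way to check any other candidate base learner before trying to boost it.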