Tags: machine-learning, scikit-learn, multilabel-classification, lightgbm

callbacks in sklearn.multiclass.OneVsRestClassifier


I want to use callbacks, eval_set, etc., but I have a problem:

from sklearn.multiclass import OneVsRestClassifier
import lightgbm
verbose = 100
params = {
    "objective": "binary",
    "n_estimators": 500,
    "verbose": 0
}
fit_params = {
    "eval_set": eval_dataset,
    "callbacks": [CustomCallback(verbose)]
}

clf = OneVsRestClassifier(lightgbm.LGBMClassifier(**params))
clf.fit(X_train, y_train,  **fit_params)

How can I pass fit_params through to my estimator? I get:

----------------------------------------------------------------------
---> 13 clf.fit(X_train, y_train,  **fit_params)

TypeError: OneVsRestClassifier.fit() got an unexpected keyword argument 'eval_set'

Solution

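  • Per scikit-learn's docs for OneVsRestClassifier, as of v1.4.0 any additional **fit_params are passed through to the wrapped estimator's fit() method only if you have enabled what scikit-learn calls "metadata routing". The TypeError you are seeing suggests an older scikit-learn version, in which OneVsRestClassifier.fit() accepted only X and y, so you will also need scikit-learn 1.4.0 or newer.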

    Your example is missing two required steps:

    • opting in, by running sklearn.set_config(enable_metadata_routing=True)
    • explicitly telling scikit-learn to pass eval_set and callbacks through to the estimator, via .set_fit_request()

    (See the "Metadata Routing" page of the scikit-learn User Guide for details.)
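The two steps above can be illustrated with a toy, standard-library-only model of opt-in routing. This is a minimal sketch under simplifying assumptions, not scikit-learn's implementation: TOY_CONFIG, ToyEstimator and ToyOneVsRest are hypothetical names, their error messages and exception types differ from scikit-learn's, and the per-class cloning that the real OneVsRestClassifier performs is skipped. It only models the rule that matters here: metadata passed to the meta-estimator's fit() reaches the sub-estimator only when routing is switched on globally and the sub-estimator has requested that metadata by name.

```python
# Toy model of opt-in metadata routing.
# NOT scikit-learn's implementation: every name here is hypothetical.

TOY_CONFIG = {"enable_metadata_routing": False}  # stand-in for sklearn.set_config


class ToyEstimator:
    """Sub-estimator that records which fit() keyword arguments reached it."""

    def __init__(self):
        self.fit_requests = set()  # metadata names this estimator opted in to
        self.received = None

    def set_fit_request(self, **requests):
        # name=True opts in to receiving that metadata; name=False withdraws it.
        for name, wanted in requests.items():
            if wanted:
                self.fit_requests.add(name)
            else:
                self.fit_requests.discard(name)
        return self  # returning self allows chaining on the constructor call

    def fit(self, X, y, **kwargs):
        self.received = dict(kwargs)
        return self


class ToyOneVsRest:
    """Meta-estimator that forwards fit() metadata only when it was requested."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        # Step 1: extra keyword arguments are rejected unless routing is on.
        if fit_params and not TOY_CONFIG["enable_metadata_routing"]:
            raise ValueError("extra fit() arguments need metadata routing enabled")
        # Step 2: every passed argument must have been requested explicitly.
        unrequested = sorted(set(fit_params) - self.estimator.fit_requests)
        if unrequested:
            raise ValueError(f"not requested via set_fit_request(): {unrequested}")
        # At this point all metadata has been requested, so forward it.
        self.estimator.fit(X, y, **fit_params)
        return self


if __name__ == "__main__":
    X, y = [[0.0], [1.0]], [0, 1]
    eval_set = [(X, y)]

    # Neither step done: rejected.
    try:
        ToyOneVsRest(ToyEstimator()).fit(X, y, eval_set=eval_set)
    except ValueError as err:
        print("1:", err)

    # Routing enabled, but eval_set not requested: still rejected.
    TOY_CONFIG["enable_metadata_routing"] = True
    try:
        ToyOneVsRest(ToyEstimator()).fit(X, y, eval_set=eval_set)
    except ValueError as err:
        print("2:", err)

    # Both steps done: eval_set is forwarded to the sub-estimator.
    sub = ToyEstimator().set_fit_request(eval_set=True)
    ToyOneVsRest(sub).fit(X, y, eval_set=eval_set)
    print("3:", sub.received)
```

In the demo, the first two calls are rejected and only the third forwards eval_set, which mirrors why both the global opt-in and the per-estimator request are needed.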

    Consider this minimal, reproducible example using Python 3.11, lightgbm==4.3.0, and scikit-learn==1.4.1.

    import lightgbm as lgb
    import sklearn
    from sklearn.datasets import make_blobs
    from sklearn.multiclass import OneVsRestClassifier
    
    # step 1: opt in to metadata routing (global scikit-learn config)
    sklearn.set_config(enable_metadata_routing=True)
    
    # create a binary classification dataset
    X, y = make_blobs(
        n_samples=10_000,
        n_features=10,
        centers=2
    )
    eval_results = {}
    
    # construct estimator
    params = {
        "objective": "binary",
        "n_estimators": 10,
    }
    fit_params = {
        # LGBMClassifier.fit() takes eval_set as (X, y) tuple(s), not an lgb.Dataset;
        # here the training data is reused, just to show that eval_set is used
        "eval_set": (X, y),
        # record_evaluation() stores per-iteration metrics in eval_results
        "callbacks": [lgb.record_evaluation(eval_results)]
    }
    
    # step 2: request that eval_set and callbacks be routed to LGBMClassifier.fit()
    clf = OneVsRestClassifier(
        lgb.LGBMClassifier(**params)
        .set_fit_request(callbacks=True, eval_set=True)
    )
    
    # train
    clf.fit(X, y, **fit_params)
    
    # check eval results, to prove that the callback was used
    print(eval_results)
    
    # example output (values vary between runs, because make_blobs is not seeded):
    # {'valid_0': OrderedDict([('binary_logloss', [0.598138869381609, 0.5203293282602738, 0.45544446427154844, 0.40059849184355334, 0.3537472248673818, 0.31338812592304066, 0.2783839141567028, 0.24785302530927006, 0.22109850424011224, 0.19756016345789282])])}