Tags: python, scikit-learn, lightgbm

Custom class that inherits from LGBMClassifier doesn't work: KeyError: 'random_state'


I create a random dataset to train an LGBM model:

from sklearn.datasets import make_classification

X, y = make_classification()

Then I train and predict with the plain LGBMClassifier with no issues:

from lightgbm import LGBMClassifier

clf = LGBMClassifier()

clf.fit(X, y=y)
clf.predict(X)
clf.predict_proba(X)

But when I create a custom subclass of LGBMClassifier, I get an error:

class MyClf(LGBMClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def fit(self, X, y=None):
        return super().fit(X, y=y)

    def predict(self, X):
        return super().predict(X)

    def predict_proba(self, X):
        return super().predict_proba(X)
    
clf = MyClf()
clf.fit(X, y=y)
clf.predict(X)
clf.predict_proba(X)

The call to clf.fit raises:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[15], line 15
     12         return super().predict_proba(X)
     14 clf = MyClf()
---> 15 clf.fit(X, y=y)
     16 clf.predict(X)
     17 clf.predict_proba(X)

Cell In[15], line 6
      5 def fit(self, X, y=None):
----> 6     return super().fit(X, y=y)

File lib/python3.9/site-packages/lightgbm/sklearn.py:890, in LGBMClassifier.fit(self, X, y, sample_weight, init_score, eval_set, eval_names, eval_sample_weight, eval_class_weight, eval_init_score, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks, init_model)
    887         else:
    888             valid_sets[i] = (valid_x, self._le.transform(valid_y))
--> 890 super().fit(X, _y, sample_weight=sample_weight, init_score=init_score, eval_set=valid_sets,
    891             eval_names=eval_names, eval_sample_weight=eval_sample_weight,
    892             eval_class_weight=eval_class_weight, eval_init_score=eval_init_score,
    893             eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds,
    894             verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature,
    895             callbacks=callbacks, init_model=init_model)
    896 return self

File lib/python3.9/site-packages/lightgbm/sklearn.py:570, in LGBMModel.fit(self, X, y, sample_weight, init_score, group, eval_set, eval_names, eval_sample_weight, eval_class_weight, eval_init_score, eval_group, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks, init_model)
    568 params.pop('n_estimators', None)
    569 params.pop('class_weight', None)
--> 570 if isinstance(params['random_state'], np.random.RandomState):
    571     params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)
    572 for alias in _ConfigAliases.get('objective'):

KeyError: 'random_state'

I couldn't find the issue even though I inspected the source code of LGBMClassifier.


Solution

  • Apparently, scikit-learn introspects the named arguments of __init__'s signature to build the estimator's parameter dictionary (via get_params()). When you override __init__ with a bare **kwargs, that signature exposes no named parameters, so entries such as random_state never reach the params dict that LGBMModel.fit reads, which is why the lookup fails. One quick fix is to copy the parent's arguments into your own __init__ (a verification sketch follows the class):

    from typing import Dict, Optional, Union

    import numpy as np
    from lightgbm import LGBMClassifier
    # Private type alias used in recent lightgbm versions' own signature;
    # on versions that lack it, substitute Callable from typing.
    from lightgbm.sklearn import _LGBM_ScikitCustomObjectiveFunction

    class MyClf(LGBMClassifier):
        def __init__(
            self,
            boosting_type: str = 'gbdt',
            num_leaves: int = 31,
            max_depth: int = -1,
            learning_rate: float = 0.1,
            n_estimators: int = 100,
            subsample_for_bin: int = 200000,
            objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
            class_weight: Optional[Union[Dict, str]] = None,
            min_split_gain: float = 0.,
            min_child_weight: float = 1e-3,
            min_child_samples: int = 20,
            subsample: float = 1.,
            subsample_freq: int = 0,
            colsample_bytree: float = 1.,
            reg_alpha: float = 0.,
            reg_lambda: float = 0.,
            random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
            n_jobs: Optional[int] = None,
            importance_type: str = 'split',
            **kwargs
        ):
            super().__init__(
                boosting_type=boosting_type,
                num_leaves=num_leaves,
                max_depth=max_depth,
                learning_rate=learning_rate,
                n_estimators=n_estimators,
                subsample_for_bin=subsample_for_bin,
                objective=objective,
                class_weight=class_weight,
                min_split_gain=min_split_gain,
                min_child_weight=min_child_weight,
                min_child_samples=min_child_samples,
                subsample=subsample,
                subsample_freq=subsample_freq,
                colsample_bytree=colsample_bytree,
                reg_alpha=reg_alpha,
                reg_lambda=reg_lambda,
                random_state=random_state,
                n_jobs=n_jobs,
                importance_type=importance_type,
                **kwargs
            )
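
    You can verify both the mechanism and the fix directly. BrokenClf below is just an illustrative stand-in for the **kwargs-only subclass from the question, and the final lines reuse X and y from the question:

    from lightgbm import LGBMClassifier

    # scikit-learn's get_params() collects the named parameters of
    # __init__'s signature; a bare **kwargs contributes none, so the
    # parameter dict comes back empty and fit() later hits the KeyError.
    class BrokenClf(LGBMClassifier):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

    print(LGBMClassifier().get_params()['random_state'])  # None -- key exists
    print(BrokenClf().get_params())                       # {} -- 'random_state' is gone

    # With the explicit signature restored, MyClf round-trips through
    # get_params() again and the original calls work:
    clf = MyClf()
    print(clf.get_params()['random_state'])  # None -- key is back
    clf.fit(X, y=y)
    clf.predict(X)
    clf.predict_proba(X)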