I create a random dataset to train an LGBM model:
from sklearn.datasets import make_classification
X, y = make_classification()
Then I train and predict with the original LGBM model with no issues:
from lightgbm import LGBMClassifier
clf = LGBMClassifier()
clf.fit(X, y=y)
clf.predict(X)
clf.predict_proba(X)
But when I create a custom subclass of LGBMClassifier, I get an error:
class MyClf(LGBMClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def fit(self, X, y=None):
        return super().fit(X, y=y)

    def predict(self, X):
        return super().predict(X)

    def predict_proba(self, X):
        return super().predict_proba(X)

clf = MyClf()
clf.fit(X, y=y)
clf.predict(X)
clf.predict_proba(X)
In clf.fit:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[15], line 15
12 return super().predict_proba(X)
14 clf = MyClf()
---> 15 clf.fit(X, y=y)
16 clf.predict(X)
17 clf.predict_proba(X)
Cell In[15], line 6
5 def fit(self, X, y=None):
----> 6 return super().fit(X, y=y)
File lib/python3.9/site-packages/lightgbm/sklearn.py:890, in LGBMClassifier.fit(self, X, y, sample_weight, init_score, eval_set, eval_names, eval_sample_weight, eval_class_weight, eval_init_score, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks, init_model)
887 else:
888 valid_sets[i] = (valid_x, self._le.transform(valid_y))
--> 890 super().fit(X, _y, sample_weight=sample_weight, init_score=init_score, eval_set=valid_sets,
891 eval_names=eval_names, eval_sample_weight=eval_sample_weight,
892 eval_class_weight=eval_class_weight, eval_init_score=eval_init_score,
893 eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds,
894 verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature,
895 callbacks=callbacks, init_model=init_model)
896 return self
File lib/python3.9/site-packages/lightgbm/sklearn.py:570, in LGBMModel.fit(self, X, y, sample_weight, init_score, group, eval_set, eval_names, eval_sample_weight, eval_class_weight, eval_init_score, eval_group, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks, init_model)
568 params.pop('n_estimators', None)
569 params.pop('class_weight', None)
--> 570 if isinstance(params['random_state'], np.random.RandomState):
571 params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)
572 for alias in _ConfigAliases.get('objective'):
KeyError: 'random_state'
I couldn't find the issue even after inspecting the source code of LGBMClassifier.
Apparently, sklearn builds its params dictionary from __init__'s signature (its argument list): BaseEstimator.get_params() introspects the constructor and ignores anything captured by **kwargs. So when you override __init__ with a bare **kwargs, entries such as random_state disappear from params, which is exactly what the KeyError complains about.
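You can see this directly by comparing get_params() on the two classes. A minimal sketch (the outputs in the comments are what I get on my run):

from lightgbm import LGBMClassifier

class MyClf(LGBMClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

# BaseEstimator.get_params() introspects the __init__ signature and
# skips **kwargs, so the subclass reports no parameters at all:
print(MyClf().get_params())                              # {}
print('random_state' in LGBMClassifier().get_params())   # True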
One quick fix I can think of is to copy the full argument list into your class:
from typing import Dict, Optional, Union

import numpy as np
from lightgbm import LGBMClassifier
# Private type alias used in LGBMClassifier's own signature; in recent
# LightGBM versions it can be imported from lightgbm.sklearn (on older
# versions, typing.Callable works as a substitute in the annotation).
from lightgbm.sklearn import _LGBM_ScikitCustomObjectiveFunction


class MyClf(LGBMClassifier):
    def __init__(
        self,
        boosting_type: str = 'gbdt',
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
        class_weight: Optional[Union[Dict, str]] = None,
        min_split_gain: float = 0.,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.,
        reg_alpha: float = 0.,
        reg_lambda: float = 0.,
        random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
        n_jobs: Optional[int] = None,
        importance_type: str = 'split',
        **kwargs
    ):
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
            **kwargs)
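With the full argument list restored, get_params() sees random_state again, and the snippet from the question should run without the KeyError:

clf = MyClf()
clf.fit(X, y=y)
clf.predict(X)
clf.predict_proba(X)

Note that if you don't actually need to customize the constructor, the simplest route is to not override __init__ at all: the subclass then inherits LGBMClassifier.__init__ with its full signature, and sklearn's introspection keeps working.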