I tried to train a LightGBM binary classifier with the Python API on the relation: if feature > 5, then 1, else 0.
import pandas as pd
import numpy as np
import lightgbm as lgb
x_train = pd.DataFrame([4, 7, 2, 6, 3, 1, 9])
y_train = pd.DataFrame([0, 1, 0, 1, 0, 0, 1])
x_test = pd.DataFrame([8, 2])
y_test = pd.DataFrame([1, 0])
lgb_train = lgb.Dataset(x_train, y_train)
lgb_eval = lgb.Dataset(x_test, y_test, reference=lgb_train)
params = { 'objective': 'binary', 'metric': {'binary_logloss', 'auc'}}
gbm = lgb.train(params, lgb_train, valid_sets=lgb_eval)
y_pred = gbm.predict(x_test, num_iteration=gbm.best_iteration)
y_pred
# array([0.42857143, 0.42857143])
np.where((y_pred > 0.5), 1, 0)
# array([0, 0])
Clearly it failed to predict the first test sample (8). Can anyone see what went wrong?
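LightGBM's parameter defaults are chosen with moderately sized training data in mind and may not work well on extremely small datasets like the one in this question. With the defaults, a 7-row training set cannot be split at all, so the model outputs the same probability for every input.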
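The constant 0.42857143 in the question can be reproduced from the training labels alone. This sketch assumes LightGBM's default boost_from_average=true (training starts from the log-odds of the positive rate) and a model that never splits, so that starting score is never changed:

```python
import numpy as np

# Training labels from the question
y_train = np.array([0, 1, 0, 1, 0, 0, 1])

# Fraction of positive labels: 3 of 7
base_rate = y_train.mean()

# Assumed starting score under boost_from_average: log-odds of the base rate
init_score = np.log(base_rate / (1.0 - base_rate))

# With no splits the score stays at init_score; the sigmoid maps it back
prediction = 1.0 / (1.0 + np.exp(-init_score))
print(base_rate, prediction)
```

Both printed values are 3/7 ≈ 0.4286, which matches the question's y_pred and stays below the 0.5 cut-off for every input.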
Two of them in particular are affecting your result:
- min_data_in_leaf: the minimum number of samples that must fall into a leaf node (default 20)
- min_sum_hessian_in_leaf: the minimum sum of the loss function's second derivatives (hessians) over the samples in a leaf node; roughly, the minimum weight a leaf must carry (default 1e-3)
Setting both to their lowest possible values lets LightGBM fit (in fact overfit) such a small dataset.
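To make the two constraints concrete, here is a simplified sketch that lists the split points on a single feature whose two children satisfy both limits. The helper feasible_thresholds is hypothetical, not part of LightGBM. It assumes the binary-logloss hessian p * (1 - p) per sample and uses midpoints between sorted values as thresholds. It ignores LightGBM's histogram binning, min_gain_to_split and its other split rules:

```python
import numpy as np


def feasible_thresholds(x, p, min_data_in_leaf, min_sum_hessian_in_leaf):
    """Hypothetical helper: thresholds on one feature whose left and right
    children both meet the sample-count and hessian-sum limits (simplified)."""
    x = np.asarray(x, dtype=float)
    p = np.asarray(p, dtype=float)
    order = np.argsort(x)
    xs = x[order]
    hess = (p * (1.0 - p))[order]  # binary-logloss hessian of each sample
    thresholds = []
    for i in range(1, len(xs)):
        if xs[i] == xs[i - 1]:
            continue  # no threshold fits between equal feature values
        n_left, n_right = i, len(xs) - i
        if min(n_left, n_right) < min_data_in_leaf:
            continue  # a child would hold too few samples
        if min(hess[:i].sum(), hess[i:].sum()) < min_sum_hessian_in_leaf:
            continue  # a child would carry too little hessian
        thresholds.append(float((xs[i - 1] + xs[i]) / 2))
    return thresholds


# Question's training feature; assume every sample starts at p = 3/7
x_train = [4, 7, 2, 6, 3, 1, 9]
p_start = np.full(len(x_train), 3 / 7)
print(feasible_thresholds(x_train, p_start, 20, 1e-3))
print(feasible_thresholds(x_train, p_start, 1, 0))
```

With the default-like limits (20, 1e-3) the list is empty, because any split of 7 rows leaves a child with fewer than 20 samples. min_data_in_leaf=3 already allows [3.5, 5.0], and 5.0 separates the two classes. min_data_in_leaf=1 allows every midpoint.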
import pandas as pd
import numpy as np
import lightgbm as lgb
x_train = pd.DataFrame([4, 7, 2, 6, 3, 1, 9])
y_train = pd.DataFrame([0, 1, 0, 1, 0, 0, 1])
x_test = pd.DataFrame([8, 2])
y_test = pd.DataFrame([1, 0])
lgb_train = lgb.Dataset(x_train, y_train)
lgb_eval = lgb.Dataset(x_test, y_test, reference=lgb_train)
params = {
'objective': 'binary',
'metric': {'binary_logloss', 'auc'},
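'min_data_in_leaf': 1,  # allow a leaf to hold a single sample (default 20)
'min_sum_hessian_in_leaf': 0  # remove the minimum-hessian limit (default 1e-3)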
}
gbm = lgb.train(params, lgb_train, valid_sets=lgb_eval)
y_pred = gbm.predict(x_test, num_iteration=gbm.best_iteration)
y_pred
# array([6.66660313e-01, 1.89048958e-05])
np.where((y_pred > 0.5), 1, 0)
# array([1, 0])
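The hessian limit becomes relevant as training progresses. For binary log loss each sample contributes a hessian of p * (1 - p), which shrinks as the model grows confident. The sketch below, with hypothetical helper names, shows where a single-sample leaf falls under the default min_sum_hessian_in_leaf of 1e-3. It is an illustration of the formula, not a trace of LightGBM's internals. The last line uses the answer's prediction for x = 2, a value that also appears in the training data. That its hessian is far below 1e-3 suggests, but does not prove, why the second parameter also needed lowering:

```python
import math


def logloss_hessian(p):
    """Per-sample hessian of binary log loss w.r.t. the raw score."""
    return p * (1.0 - p)


def saturation_probability(min_sum_hessian):
    """Probability above which a one-sample leaf's hessian drops below
    min_sum_hessian (larger root of p * (1 - p) = min_sum_hessian).
    Valid for 0 <= min_sum_hessian <= 0.25."""
    return (1.0 + math.sqrt(1.0 - 4.0 * min_sum_hessian)) / 2.0


print(logloss_hessian(3 / 7))          # about 0.245 at the start of training
print(saturation_probability(1e-3))    # about 0.999
print(logloss_hessian(1.89048958e-05)) # about 1.9e-05
```

By symmetry, a probability below about 0.001 also gives a single-sample hessian under 1e-3.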
For details on all the parameters and their defaults, see https://lightgbm.readthedocs.io/en/latest/Parameters.html.