
Why does this simple LightGBM binary classifier perform poorly?


I tried to train a LightGBM binary classifier with the Python API on the relation: if feature > 5, then 1, else 0.

import pandas as pd
import numpy as np
import lightgbm as lgb
x_train = pd.DataFrame([4, 7, 2, 6, 3, 1, 9])
y_train = pd.DataFrame([0, 1, 0, 1, 0, 0, 1])
x_test = pd.DataFrame([8, 2])
y_test = pd.DataFrame([1, 0])
lgb_train = lgb.Dataset(x_train, y_train)
lgb_eval = lgb.Dataset(x_test, y_test, reference=lgb_train)
params = { 'objective': 'binary', 'metric': {'binary_logloss', 'auc'}}
gbm = lgb.train(params, lgb_train, valid_sets=lgb_eval)
y_pred = gbm.predict(x_test, num_iteration=gbm.best_iteration)

y_pred
array([0.42857143, 0.42857143])

np.where((y_pred > 0.5), 1, 0)
array([0, 0])

Clearly it failed to predict the first test sample (8), which should be 1. Can anyone see what went wrong?


Solution

  • LightGBM's parameter defaults are chosen with moderately sized training data in mind and may not work well on extremely small datasets like the one in this question.

    There are two in particular that are impacting your result:

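    • min_data_in_leaf: minimum number of samples that must fall into a leaf node (default 20)
    • min_sum_hessian_in_leaf: minimum sum of the hessians (second derivatives of the loss) over the samples in a leaf node (default 1e-3); roughly, a floor on how much one leaf can contribute to the loss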
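
    The effect of these two limits on this dataset can be sketched in plain Python. This is a simplified model, not LightGBM's implementation: it ignores feature binning and the library's other split checks, treats candidate thresholds as midpoints between sorted feature values, and gives every sample the first-round hessian of binary log loss, p(1 - p), with p equal to the positive rate. `feasible_thresholds` is a hypothetical helper.

```python
# Simplified sketch of two LightGBM leaf constraints. Assumptions: no feature
# binning, candidate thresholds are midpoints between distinct sorted values,
# and every sample has the first-round hessian of binary log loss.
x_train = [4, 7, 2, 6, 3, 1, 9]
y_train = [0, 1, 0, 1, 0, 0, 1]

p = sum(y_train) / len(y_train)  # starting prediction: positive rate 3/7
hessian = p * (1 - p)            # per-sample hessian of log loss, 12/49


def feasible_thresholds(xs, min_data_in_leaf, min_sum_hessian_in_leaf):
    """Hypothetical helper: thresholds whose left and right children both
    satisfy min_data_in_leaf and min_sum_hessian_in_leaf."""
    values = sorted(xs)
    result = []
    for k in range(1, len(values)):
        if values[k - 1] == values[k]:
            continue  # equal values cannot be separated by a threshold
        left, right = k, len(values) - k  # sample counts in each child
        enough_data = min(left, right) >= min_data_in_leaf
        enough_hessian = min(left, right) * hessian >= min_sum_hessian_in_leaf
        if enough_data and enough_hessian:
            result.append((values[k - 1] + values[k]) / 2)
    return result


print(feasible_thresholds(x_train, 20, 1e-3))  # LightGBM defaults -> []
print(feasible_thresholds(x_train, 1, 0))      # -> [1.5, 2.5, 3.5, 5.0, 6.5, 8.0]
```

    With the defaults no threshold survives, while with min_data_in_leaf=1 and min_sum_hessian_in_leaf=0 every threshold, including the true boundary at 5.0, becomes available. In this first round a single sample's hessian (12/49 ≈ 0.245) already exceeds 1e-3, so min_data_in_leaf is what blocks the split here; the hessian floor can start to matter once predictions move toward 0 or 1 and p(1 - p) shrinks.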

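    With the default min_data_in_leaf of 20, no split is possible on 7 training rows, since each child of a split would need at least 20 rows. The trees therefore cannot grow past a single leaf, and the model predicts the training positive rate, 3/7 ≈ 0.42857143, for every input, which is exactly what both test points received. Setting these parameters to their lowest possible values lets LightGBM split, and overfit, even on a dataset this small.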
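
    The 0.42857143 in the question can be reproduced by hand. This sketch assumes LightGBM's default boost_from_average behavior for the binary objective (training starts from the log-odds of the positive rate) and that, because no tree can split, the raw score never moves away from that starting value.

```python
import math

y_train = [0, 1, 0, 1, 0, 0, 1]
min_data_in_leaf = 20  # LightGBM default

# A split needs at least min_data_in_leaf rows on each side.
can_split = len(y_train) >= 2 * min_data_in_leaf  # False: 7 < 40

# Assumption: with boost_from_average, boosting starts from the log-odds of
# the positive rate, and without any split the raw score stays there.
rate = sum(y_train) / len(y_train)            # 3/7
init_score = math.log(rate / (1 - rate))      # log(3/4)
prediction = 1 / (1 + math.exp(-init_score))  # sigmoid maps it back to 3/7

print(can_split, round(prediction, 8))  # False 0.42857143
```

    Because the score never depends on the feature, both test inputs, 8 and 2, receive the same prediction.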

    import pandas as pd
    import numpy as np
    import lightgbm as lgb
    
    x_train = pd.DataFrame([4, 7, 2, 6, 3, 1, 9])
    y_train = pd.DataFrame([0, 1, 0, 1, 0, 0, 1])
    x_test = pd.DataFrame([8, 2])
    y_test = pd.DataFrame([1, 0])
    
    lgb_train = lgb.Dataset(x_train, y_train)
    lgb_eval = lgb.Dataset(x_test, y_test, reference=lgb_train)
    
    params = {
        'objective': 'binary',
        'metric': {'binary_logloss', 'auc'},
        'min_data_in_leaf': 1,
        'min_sum_hessian_in_leaf': 0
    }
    gbm = lgb.train(params, lgb_train, valid_sets=lgb_eval)
    y_pred = gbm.predict(x_test, num_iteration=gbm.best_iteration)
    
    y_pred
    # array([6.66660313e-01, 1.89048958e-05])
    
    np.where((y_pred > 0.5), 1, 0)
    # array([1, 0])
    

    For details on all the parameters and their defaults, see https://lightgbm.readthedocs.io/en/latest/Parameters.html.