Tags: python, machine-learning, lightgbm, boosting

How many trees do I actually have in my LightGBM model?


I have code that looks like this:

import lightgbm as lgb
import matplotlib.pyplot as plt

clf = lgb.LGBMClassifier(max_depth=3, verbosity=-1, n_estimators=3)
clf.fit(train_data[features], train_data['y'], sample_weight=train_data['weight'])
print(f"I have {clf.n_estimators_} estimators")
fig, ax = plt.subplots(nrows=4, figsize=(50, 36), sharex=True)
lgb.plot_tree(clf, tree_index=7, dpi=600, ax=ax[0])  # why does a tree at index 7 exist?
lgb.plot_tree(clf, tree_index=8, dpi=600, ax=ax[1])  # why does a tree at index 8 exist?
# lgb.plot_tree(clf, tree_index=9, dpi=600, ax=ax[2])  # crashes
# lgb.plot_tree(clf, tree_index=10, dpi=600, ax=ax[3])  # crashes

I am surprised that, despite n_estimators=3, I seem to have 9 trees. How do I actually set the number of trees and, related to that, what does n_estimators do? I've read the docs, and I thought it would be the number of trees, but it seems to be something else.

Separately, how do I interpret the separate trees and their ordering 0, 1, 2, etc.? I know random forests, where every tree is equally important. In boosting, the first tree is most important, and each subsequent tree contributes less. So, when I look at the tree diagrams, how can I "simulate" the LightGBM inference process in my head?


Solution

  • How do I actually set the number of trees, and related to that, what does n_estimators do?

    Pass n_estimators or one of its aliases (LightGBM docs).

    n_estimators in LightGBM's scikit-learn interface (classes like LGBMClassifier) controls the number of boosting rounds.

    For all tasks other than multiclass classification, LightGBM will produce 1 tree per boosting round.

    For multiclass classification, LightGBM will train 1 tree per class in each boosting round.

    So, for example, if your target has 5 classes, then training with n_estimators=3 and no early stopping will produce 15 trees.
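
    This also explains seeing 9 trees with n_estimators=3: that count is consistent with a 3-class target (3 rounds × 3 classes). As a quick check, here is a minimal sketch assuming a synthetic 3-class dataset in place of your train_data:

    import lightgbm as lgb
    from sklearn.datasets import make_blobs

    # hypothetical 3-class dataset standing in for train_data
    X, y = make_blobs(n_samples=500, centers=3, random_state=0)

    clf = lgb.LGBMClassifier(max_depth=3, n_estimators=3, verbosity=-1)
    clf.fit(X, y)

    # 3 boosting rounds * 3 classes = 9 trees, i.e. valid tree_index values are 0..8
    print(clf.booster_.num_trees())  # 9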

  • How do I interpret the separate trees, with their ordering, 0, 1, 2, etc.? How can I "simulate" the LightGBM inference process?

    Each consecutive group of num_classes trees corresponds to one boosting round; within each round, the trees are ordered by target class.

    Given an input X, LightGBM's raw score for X belonging to class i is:

    tree_{i}(X) +
    tree_{i+num_classes}(X) +
    tree_{i+num_classes*2}(X) +
    ... (one term per boosting round)
    

    So, for example, consider 5-class multiclass classification and 3 boosting rounds, using LightGBM's built-in multiclass objective.

    LightGBM's raw score for a sample x belonging to the first class will be the sum of the corresponding leaf values from the 1st, 6th, and 11th trees (tree_index values 0, 5, and 10).

    Here's a minimal, reproducible example with lightgbm==4.3.0 and Python 3.11.

    import lightgbm as lgb
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_blobs
    
    # generate multiclass dataset with 5 classes
    X, y = make_blobs(n_samples=1_000, centers=5, random_state=773)
    
    # fit a small multiclass classification model
    clf = lgb.LGBMClassifier(n_estimators=3, num_leaves=4, seed=708)
    clf.fit(X, y)
    
    # underlying model has 15 trees
    clf.booster_.num_trees()
    # 15
    
    # but just 3 iterations (boosting rounds)
    clf.n_iter_
    # 3
    
    # just plot the trees for the first class
    lgb.plot_tree(clf, tree_index=0)
    lgb.plot_tree(clf, tree_index=5)
    lgb.plot_tree(clf, tree_index=10)
    plt.show()
    

    This produces tree diagrams for trees 0, 5, and 10 (the first-class tree from each boosting round).

    Try generating the raw predictions for the first row of the training data.

    clf.predict(X, raw_score=True)[0,]
    # array([-1.24209749, -1.90204682, -1.9020346 , -1.89711144, -1.23250193])
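
    As a side note, those raw scores are the pre-softmax values for each class. Assuming the default multiclass (softmax) objective, applying a softmax across the classes should recover predict_proba; a quick sketch of that check, continuing the example above:

    import numpy as np

    raw = clf.predict(X, raw_score=True)

    # row-wise softmax over the 5 class scores
    probs = np.exp(raw) / np.exp(raw).sum(axis=1, keepdims=True)

    # should match predict_proba up to floating-point precision
    np.testing.assert_allclose(probs, clf.predict_proba(X))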
    

    You could manually trace which leaf node that sample falls into in each of those trees (0, 5, and 10) and add up the leaf values. That sum should match the first item in the raw score above (in this example, -1.24209749).
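
    For example, here is a rough sketch of that lookup done without pandas, using the booster's JSON dump (dump_model()) together with pred_leaf=True. The helper function find_leaf_value is defined here just for illustration, and the sketch assumes every tree has at least one split (true in this example):

    def find_leaf_value(node, leaf_index):
        # helper (for illustration): recursively search a dumped tree
        # for the leaf with the given index and return its value
        if "leaf_index" in node:
            return node["leaf_value"] if node["leaf_index"] == leaf_index else None
        for child in ("left_child", "right_child"):
            value = find_leaf_value(node[child], leaf_index)
            if value is not None:
                return value
        return None

    dumped = clf.booster_.dump_model()

    # leaf index reached by the first sample in each of the 15 trees
    first_row_leaves = clf.predict(X, pred_leaf=True)[0,]

    # sum the leaf values the first sample reaches in trees 0, 5, and 10
    manual_raw_score = sum(
        find_leaf_value(dumped["tree_info"][t]["tree_structure"], first_row_leaves[t])
        for t in (0, 5, 10)
    )
    print(manual_raw_score)  # approximately -1.24209749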

    If you have pandas available, you might find it easier to dump the tree structure to a dataframe and work with it there.

    model_df = clf.booster_.trees_to_dataframe()
    
    # trees relevant to class 0
    relevant_trees = [0, 5, 10]
    
    # figure out which leaf index each sample falls into
    leaf_preds = clf.predict(X, pred_leaf=True)[0,]
    
    # subset that to only the trees relevant to class 0
    relevant_leaf_ids = [
        f"0-L{leaf_preds[0]}",
        f"5-L{leaf_preds[5]}",
        f"10-L{leaf_preds[10]}"
    ]
    
    # show the values LightGBM would predict from each tree
    model_df[
      model_df["tree_index"].isin(relevant_trees) &
      model_df["node_index"].isin(relevant_leaf_ids)
    ][["tree_index", "node_index", "value"]]
    
        tree_index node_index     value
    5            0       0-L3 -1.460720
    38           5       5-L3  0.119902
    73          10      10-L3  0.098720       
    

    Those 3 values add to -1.242098, matching the score from clf.predict(X, raw_score=True) up to the precision shown in the printed output.
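
    To check that programmatically rather than by eye, you could sum the values from that subset and compare against the raw score, continuing the example above:

    import numpy as np

    manual_score = model_df[
        model_df["tree_index"].isin(relevant_trees) &
        model_df["node_index"].isin(relevant_leaf_ids)
    ]["value"].sum()

    # matches the class-0 raw score up to floating-point precision
    np.testing.assert_allclose(manual_score, clf.predict(X, raw_score=True)[0, 0])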