I have code that looks like this:
clf = lgb.LGBMClassifier(max_depth=3, verbosity=-1, n_estimators=3)
clf.fit(train_data[features], train_data['y'], sample_weight=train_data['weight'])
print (f"I have {clf.n_estimators_} estimators")
fig, ax = plt.subplots(nrows=4, figsize=(50,36), sharex=True)
lgb.plot_tree(clf, tree_index=7, dpi=600, ax=ax[0]) # why does it have 7th tree?
lgb.plot_tree(clf, tree_index=8, dpi=600, ax=ax[1]) # why does it have 8th tree?
#lgb.plot_tree(clf, tree_index=9, dpi=600, ax=ax[2]) # crashes
#lgb.plot_tree(clf, tree_index=10, dpi=600, ax=ax[3]) # crashes
I am surprised that despite n_estimators=3, I seem to have 9 trees. How do I actually set the number of trees, and related to that, what does n_estimators do? I've read the docs, and I thought it would be the number of trees, but it seems to be something else.
Separately, how do I interpret the individual trees and their ordering (0, 1, 2, etc.)? I know random forests, where every tree is equally important. In boosting, the first tree is most important, the next one significantly less, the next significantly less again. So, when I look at the tree diagrams, how can I "simulate" the LightGBM inference process in my head?
How do I actually set the number of trees, and related to that, what does n_estimators do?
Pass n_estimators or one of its aliases (LightGBM docs). n_estimators in LightGBM's scikit-learn interface (classes like LGBMClassifier) controls the number of boosting rounds.
For all tasks other than multiclass classification, LightGBM will produce 1 tree per boosting round.
For multiclass classification, LightGBM will train 1 tree per class in each boosting round.
So, for example, if your target has 5 classes, then training with n_estimators=3 and no early stopping will produce 15 trees.
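You can confirm this on your own model. Presumably your target has 3 classes, which would explain the 9 trees you observed (3 rounds x 3 classes). A quick check, using the clf from your snippet:

# clf is the fitted LGBMClassifier from your snippet
n_classes = len(clf.classes_)
n_trees = clf.booster_.num_trees()
print(n_classes, n_trees)  # expect 3 classes and 3 * 3 = 9 trees
assert n_trees == clf.n_estimators_ * n_classes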
how do I interpret the separate trees, with their ordering, 0, 1, 2, etc... how can I "simulate" the LightGBM inference process?
Each consecutive grouping of num_classes trees corresponds to one boosting round. Within each round, the trees are ordered by target class.
Given an input X, LightGBM's raw score for X belonging to class i is given by:

tree_{i}(X) + tree_{i + num_classes}(X) + tree_{i + 2*num_classes}(X) + ...

with one term per boosting round.
So, for example, consider 5-class multiclass classification and 3 boosting rounds, using LightGBM's built-in multiclass objective.
LightGBM's score for a sample x belonging to the first class will be the sum of the corresponding leaf values from the 1st, 6th, and 11th trees (tree indices 0, 5, and 10).
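In index terms, that pattern is easy to compute. A small sketch follows; num_classes, n_rounds, and class_idx are just illustrative local names, not LightGBM parameters:

# 0-based indices of the trees that contribute to the raw score for class i
num_classes = 5
n_rounds = 3
class_idx = 0  # the "first class" in the example above

tree_indices = [class_idx + num_classes * r for r in range(n_rounds)]
print(tree_indices)  # [0, 5, 10] -> the 1st, 6th, and 11th trees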
Here's a minimal, reproducible example with lightgbm==4.3.0 and Python 3.11.
import lightgbm as lgb
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
# generate multiclass dataset with 5 classes
X, y = make_blobs(n_samples=1_000, centers=5, random_state=773)
# fit a small multiclass classification model
clf = lgb.LGBMClassifier(n_estimators=3, num_leaves=4, seed=708)
clf.fit(X, y)
# underlying model has 15 trees
clf.booster_.num_trees()
# 15
# but just 3 iterations (boosting rounds)
clf.n_iter_
# 3
# just plot the trees for the first class
lgb.plot_tree(clf, tree_index=0)
lgb.plot_tree(clf, tree_index=5)
lgb.plot_tree(clf, tree_index=10)
plt.show()
This will produce one tree diagram per call, one for each of the three trees used for the first class.
Try generating the raw predictions for the first row of the training data.
clf.predict(X, raw_score=True)[0,]
# array([-1.24209749, -1.90204682, -1.9020346 , -1.89711144, -1.23250193])
You could manually calculate which leaf node that sample belongs to in each tree and add those leaf values. That number should match the first item in the raw score above (in this example, -1.24209749).
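Here's a rough sketch of one way to do that manual calculation without pandas, by walking the dumped model structure (the leaf_value helper is just an illustrative name, not part of LightGBM's API):

# leaf index of the first sample in every tree, then sum the leaf values
# for the trees that belong to class 0 (trees 0, 5, and 10)
leaf_idx = clf.predict(X, pred_leaf=True)[0, ]
booster_dump = clf.booster_.dump_model()

def leaf_value(tree, leaf_id):
    # walk the nested tree structure until we find the leaf with this index
    stack = [tree["tree_structure"]]
    while stack:
        node = stack.pop()
        if node.get("leaf_index") == leaf_id:
            return node["leaf_value"]
        stack.extend(node[k] for k in ("left_child", "right_child") if k in node)

score = sum(
    leaf_value(booster_dump["tree_info"][t], leaf_idx[t])
    for t in (0, 5, 10)
)
print(score)  # should be close to -1.24209749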
If you have pandas available, you might find it easier to dump the tree structure to a dataframe and work with it there.
model_df = clf.booster_.trees_to_dataframe()
# trees relevant to class 0
relevant_trees = [0, 5, 10]
# figure out which leaf node the first sample falls into in each tree
leaf_preds = clf.predict(X, pred_leaf=True)[0,]
# subset that to only the trees relevant to class 0
relevant_leaf_ids = [
f"0-L{leaf_preds[0]}",
f"5-L{leaf_preds[5]}",
f"10-L{leaf_preds[10]}"
]
# show the values LightGBM would predict from each tree
model_df[
model_df["tree_index"].isin(relevant_trees) &
model_df["node_index"].isin(relevant_leaf_ids)
][["tree_index", "node_index", "value"]]
    tree_index node_index     value
5            0       0-L3 -1.460720
38           5       5-L3  0.119902
73          10      10-L3  0.098720
Those 3 values add to -1.242098, almost identical to the score predicted by clf.predict(X, raw_score=True) (different only by the numerical precision lost in printing).
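If you'd rather confirm the match programmatically than by eye, you can compare the summed leaf values against the raw score directly (continuing from the variables defined above):

# compare the class-0 leaf-value sum for the first sample against the raw score
leaf_sum = model_df[
    model_df["tree_index"].isin(relevant_trees) &
    model_df["node_index"].isin(relevant_leaf_ids)
]["value"].sum()

raw_score_class0 = clf.predict(X, raw_score=True)[0, 0]
print(leaf_sum, raw_score_class0)
assert np.isclose(leaf_sum, raw_score_class0)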