I'm trying to find a way to train a LightGBM model that is forced to use some features in its splits (i.e. so that they show up in the feature importance), so that the predictions are affected by these variables.
Here is an example of the modeling code with a useless variable (it is constant), but the idea is that there could be a variable that is important from a business perspective and yet does not appear in the feature importance:
from lightgbm import LGBMRegressor
import pandas as pd
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
# Generate a random regression dataset
X, y = make_regression(n_samples=1000, n_features=10, noise=0.9, random_state=42)
feature_names = [f"feature_{i}" for i in range(X.shape[1])]
# Convert to a DataFrame for readability
X = pd.DataFrame(X, columns=feature_names)
# Add a useless (constant) feature
X["useless_feature_1"] = 1
# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Define the LGBMRegressor model
model = LGBMRegressor(
    objective="regression",
    metric="rmse",
    random_state=1,
    n_estimators=100
)
# Train the model
model.fit(X_train, y_train, eval_set=[(X_test, y_test)])
# Predictions and evaluation
y_pred = model.predict(X_test)
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
print(f"Test RMSE: {rmse:.4f}")
# Feature importance
importance = pd.DataFrame({
    "feature": X.columns,
    "importance": model.feature_importances_
}).sort_values(by="importance", ascending=False)
print("\nFeature Importance:")
print(importance)
Expected solution: there are probably some workarounds, but the most interesting one would use a parameter of the fit() method or of the regressor's constructor.
As of this writing, LightGBM does not have functionality like "force at least 1 split on a given feature, but let LightGBM choose the threshold".
However, it is possible to force LightGBM to split on specific features with specific thresholds.
Here's an example (I tested it with lightgbm 4.5.0):
import json
import lightgbm as lgb
import numpy as np
from sklearn.datasets import make_regression
X, y = make_regression(
    n_samples=10_000,
    n_features=5,
    n_informative=5,
    random_state=42
)
# add a noise feature
noise_feature = np.random.random(size=(X.shape[0], 1))
X = np.concatenate((X, noise_feature), axis=1)
# train a small model
model1 = lgb.LGBMRegressor(
    random_state=708,
    n_estimators=10,
)
model1.fit(X, y)
# notice that the noise feature (the 6th one) was never chosen for a split
model1.feature_importances_
# array([ 0, 97, 110, 0, 93, 0], dtype=int32)
# force the use of that noise feature in every tree
forced_split = {
    "feature": 5,
    "threshold": np.mean(noise_feature),
}
with open("forced_splits.json", "w") as f:
    f.write(json.dumps(forced_split))
# train another model, forcing it to use those splits
model2 = lgb.LGBMRegressor(
    random_state=708,
    n_estimators=10,
    forcedsplits_filename="forced_splits.json",
)
model2.fit(X, y)
# noise feature was used once in every tree
model2.feature_importances_
# array([ 0, 104, 131, 0, 55, 10], dtype=int32)
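The forced-split JSON refers to features by integer position ("feature": 5 above), while the question trains on a pandas DataFrame with named columns. Below is a minimal sketch of a hypothetical helper, write_forced_split(), that maps a column name to its position before writing the file. It assumes the column list you pass in is in the same order as the columns used for training (e.g. list(X_train.columns)).

```python
import json


def write_forced_split(path, feature_names, feature_name, threshold):
    """Hypothetical helper: write one forced split for `feature_name` to `path`.

    The forced-split JSON uses the feature's integer position, so the
    column name is looked up in `feature_names` (same order as training).
    """
    spec = {
        "feature": list(feature_names).index(feature_name),
        # float() also converts numpy scalars such as np.float32,
        # which json.dump cannot serialize on its own
        "threshold": float(threshold),
    }
    with open(path, "w") as f:
        json.dump(spec, f)
    return spec
```

For the question's setup this would be something like write_forced_split("forced_splits.json", list(X_train.columns), "feature_3", X_train["feature_3"].mean()), followed by LGBMRegressor(..., forcedsplits_filename="forced_splits.json"). Keep in mind that a split needs rows on both sides of the threshold, so a constant column such as useless_feature_1 has no threshold that separates anything.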
The JSON file defining the splits can be nested arbitrarily deep, using "left" and "right" keys for the child splits (see forcedsplits_filename in the LightGBM docs).
For example, here's how to force it to use the 6th, 1st, and 4th features (in that order), split on their means, all down the left side of each tree.
forced_split = {
    "feature": 5,
    "threshold": np.mean(noise_feature),
    "left": {
        "feature": 0,
        "threshold": np.mean(X[:, 0]),
        "left": {
            "feature": 3,
            "threshold": np.mean(X[:, 3]),
        }
    }
}
with open("forced_splits.json", "w") as f:
    f.write(json.dumps(forced_split))
model3 = lgb.LGBMRegressor(
    random_state=708,
    n_estimators=10,
    forcedsplits_filename="forced_splits.json",
).fit(X, y)
model3.feature_importances_
# array([ 10, 114, 133, 10, 23, 10], dtype=int32)
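Writing deeper specs by hand gets repetitive. The sketch below uses two hypothetical helpers, left_chain() and forced_split_counts() (plain Python, no LightGBM needed). The first builds the same kind of left-side chain from a list of (feature_index, threshold) pairs, and the second counts how often each feature appears in a spec. Multiplying those counts by n_estimators gives the number of splits the forced part of the trees contributes when every forced split is applied, which lines up with the 10s for features 0, 3 and 5 in the output above.

```python
from collections import Counter


def left_chain(splits):
    """Nest (feature_index, threshold) pairs so that each split sits in the
    "left" child of the previous one. Returns None for an empty list."""
    spec = None
    # Build from the deepest split upward so each node wraps the next one
    for feature, threshold in reversed(splits):
        node = {"feature": int(feature), "threshold": float(threshold)}
        if spec is not None:
            node["left"] = spec
        spec = node
    return spec


def forced_split_counts(spec):
    """Count how often each feature index appears in a forced-split spec,
    following both "left" and "right" children."""
    counts = Counter()
    stack = [spec] if spec else []
    while stack:
        node = stack.pop()
        counts[node["feature"]] += 1
        stack.extend(node[side] for side in ("left", "right") if side in node)
    return counts


spec = left_chain([(5, 0.5), (0, 0.0), (3, 0.0)])
n_estimators = 10
forced_per_model = {f: n_estimators * c for f, c in forced_split_counts(spec).items()}
```

Here forced_per_model is {5: 10, 0: 10, 3: 10}. The thresholds in this usage line are placeholders; in practice you would pass the column means as in the example above.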
If you don't want the same structure for every tree, you could look into "training continuation", changing this parameter for each batch of training rounds. See "LightGBM: train() vs update() vs refit()".
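A rough, untested sketch of that training-continuation idea, assuming lightgbm's lgb.train() and its init_model argument: each batch of rounds gets its own forced-split file, and the booster from the previous batch is passed in as init_model. The helper names (write_batch_specs, train_in_batches), the file naming scheme and the default params are all hypothetical. lightgbm is only imported when train_in_batches() is called, so the file-writing part runs without it.

```python
import json
from pathlib import Path


def write_batch_specs(specs, workdir="."):
    """Write one forced-split JSON file per batch and return the paths."""
    paths = []
    for i, spec in enumerate(specs):
        path = Path(workdir) / f"forced_splits_batch_{i}.json"
        path.write_text(json.dumps(spec))
        paths.append(path)
    return paths


def train_in_batches(X, y, specs, rounds_per_batch=5, params=None, workdir="."):
    """Train `rounds_per_batch` rounds per spec, swapping forcedsplits_filename
    between batches and continuing from the previous booster."""
    import lightgbm as lgb  # imported lazily; requires lightgbm to be installed

    base = dict(params or {"objective": "regression", "verbose": -1})
    booster = None
    for path in write_batch_specs(specs, workdir):
        booster = lgb.train(
            {**base, "forcedsplits_filename": str(path)},
            lgb.Dataset(X, label=y),
            num_boost_round=rounds_per_batch,
            init_model=booster,  # None on the first batch, so training starts fresh
        )
    return booster
```

For example, train_in_batches(X, y, [{"feature": 5, "threshold": 0.5}, {"feature": 0, "threshold": 0.0}]) should train 5 rounds forcing feature 5 at the root of each tree, then 5 more forcing feature 0; booster.feature_importance() then reports split counts as before.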