python scikit-learn catboost

Cross-validation with CatBoostRegressor never stops


I use this code to run cross-validation with CatBoost. However, after 10 hours the console is still printing output, and the cross-validation has obviously run for more than 5 rounds.
What is the problem?

import pandas as pd
from sklearn.model_selection import train_test_split
import catboost
# from sklearn.model_selection import KFold
from sklearn.feature_selection import RFECV

train_data = pd.read_csv('train.txt',sep='\t')
test_data = pd.read_csv('test.txt',sep='\t')
X = train_data.iloc[:,:-1]
y = train_data['target']
model = catboost.CatBoostRegressor(
                           loss_function="RMSE",
                           eval_metric="RMSE",
                           task_type="GPU",
                           learning_rate=0.01,
                           iterations=10000,
                           random_seed=42,
                           od_type="Iter",
                           depth=10,
                           early_stopping_rounds=50
                          )
rfecv = RFECV(estimator = model,cv = 5,scoring = 'neg_mean_squared_error')
rfecv.fit(X, y)
df = pd.DataFrame(rfecv.predict(test_data))
df.to_csv("my.txt", index=False, header=False)

I then stopped the program and changed iterations to 100. This time the console iterated a total of 161 times and then stopped. I expected it to stop after 5 rounds of cross-validation, but it clearly ran more than 5.
Why does it behave like this?


Solution

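  • With RFECV, cv=5 does not mean 5 fits. RFECV runs recursive feature elimination inside every fold and refits the model each time it drops a feature, so the number of CatBoost fits is roughly cv times the number of features, plus a final elimination pass on the full data. There are several easy ways to cut the run time:

    1. Reduce iterations to a more reasonable 1000, or even 300 if run time is still a problem.
    2. Reduce cv, e.g. to 3.
    3. Check on a smaller subset of your data whether it finishes within a reasonable amount of time. Then extrapolate to the whole dataset, assuming the algorithm scales roughly linearly, and decide whether you're comfortable with that run time.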

    Note that RFECV is a greedy algorithm: it prunes the weakest features one at a time (step=1 by default), so it can settle on a locally optimal feature subset instead of the best one.
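
To see why the run is so much longer than 5 fits, here is a minimal sketch that counts how often RFECV calls fit(). It is not the asker's setup: CountingRegressor and count_fits are hypothetical helpers written for this illustration, a cheap LinearRegression stands in for CatBoost so it runs in seconds, and the data is synthetic.

```python
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.datasets import make_regression
from sklearn.feature_selection import RFECV
from sklearn.linear_model import LinearRegression


class CountingRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical helper: a cheap linear model that counts calls to fit()."""

    n_fits = 0  # class-level counter, shared by every clone RFECV creates

    def fit(self, X, y):
        CountingRegressor.n_fits += 1
        inner = LinearRegression().fit(X, y)
        self.coef_ = inner.coef_  # RFECV ranks features by these coefficients
        self.intercept_ = inner.intercept_
        return self

    def predict(self, X):
        return X @ self.coef_ + self.intercept_


def count_fits(n_features, cv, step=1):
    """Run RFECV on synthetic data and return how many times fit() was called."""
    X, y = make_regression(n_samples=200, n_features=n_features, random_state=0)
    CountingRegressor.n_fits = 0
    selector = RFECV(CountingRegressor(), cv=cv, step=step,
                     scoring="neg_mean_squared_error")
    selector.fit(X, y)
    return CountingRegressor.n_fits


# Same RFECV call as in the question (cv=5, default step=1), on 8 features.
print("cv=5, step=1:", count_fits(8, cv=5))
# Fewer folds, as suggested in step 2.
print("cv=3, step=1:", count_fits(8, cv=3))
# Dropping two features per round instead of one.
print("cv=5, step=2:", count_fits(8, cv=5, step=2))
```

Each of those fits would be a full CatBoost training run in the original code, which is why lowering iterations, cv, or the number of elimination rounds (via RFECV's step parameter) shortens the total run time.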