Save and Load a catboost model with Caret in R

I can train a CatBoost model with caret (in RStudio), and it works great.

my_catboost <- caret::train(x, y,
                            tuneGrid = param,
                            metric = "ROC")

If I use the model to predict on new data in the same session, there is no issue; it works:

output <- caret::predict.train(my_catboost, newdata=x_testing, type="prob")

However, if I save the model and load it later (or save it, remove my_catboost, and load it again), calling predict crashes both R and RStudio with no error message, and I can't find anything in the RStudio log. After loading, the model appears in the Global Environment and looks fine.

I tried both save()/load() and saveRDS()/readRDS(), and both crash.

Thanks!


  • You have misunderstood my comment. Here is an answer using the Sonar data set from the mlbench package:


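    Load the packages and data, then create train and test data sets:

    library(caret)
    library(catboost)                     # provides catboost.caret
    data(Sonar, package = "mlbench")
    tr <- createDataPartition(Sonar$Class, p = 0.7, list = FALSE)
    trainer <- Sonar[tr,]
    tester <- Sonar[-tr,]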

    Train the model:

    fitControl <- trainControl(method = "cv",
                               number = 3,
                               savePredictions = TRUE,
                               summaryFunction = twoClassSummary,
                               classProbs = TRUE)
    model <- train(x = trainer[,1:60],
                   y = trainer$Class,
                   method = catboost.caret, 
                   trControl = fitControl, 
                   tuneLength = 5,
                   metric = "ROC")

    Predict using caret:

    preds1 <- predict(model, tester, type = "prob")

    Save the final catboost model with catboost's own function:

    catboost::catboost.save_model(model$finalModel, "model")

    Load the saved model:

    model2 <- catboost::catboost.load_model("model")

    Predict using the loaded model (catboost.predict needs a pool built from the test features):

    preds2 <- catboost.predict(model2,
                               catboost.load_pool(tester[,1:60]),
                               prediction_type = "Probability")

    Check that the predictions are equal:

    all.equal(preds1[,2], preds2)

    EDIT: by contrast, saving and reloading the whole caret object:

    saveRDS(model, "caret.model.rds")
    model3 <- readRDS("caret.model.rds")
    preds3 <- predict(model3, tester, type = "prob")

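    crashes the R session, most likely because the catboost model inside model$finalModel refers to native (C++) memory through an external pointer, which saveRDS() cannot serialize.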
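One possible workaround, sketched here under assumptions: keep the caret object (tuning results, preprocessing), but persist the catboost model itself with catboost.save_model() and swap a freshly loaded copy back into model$finalModel after readRDS(). The helper names save_caret_catboost() and load_caret_catboost() and the file names are hypothetical. Copying the extra fields caret stores on finalModel (such as obsLevels and xNames) is an assumption about what caret's catboost wrapper needs at prediction time, so compare the output with predictions from the original model before relying on it.

```r
library(caret)
library(catboost)

# Hypothetical helper: the caret wrapper goes through saveRDS(); the catboost
# model, whose native handle does not survive saveRDS(), goes through
# catboost.save_model().
save_caret_catboost <- function(model, dir) {
  dir.create(dir, showWarnings = FALSE, recursive = TRUE)
  catboost::catboost.save_model(model$finalModel,
                                file.path(dir, "final_model.cbm"))
  saveRDS(model, file.path(dir, "caret_model.rds"))
  invisible(dir)
}

# Hypothetical helper: reload both parts and re-attach a freshly loaded
# catboost model, so caret's predict() never uses the stale native handle.
load_caret_catboost <- function(dir) {
  model <- readRDS(file.path(dir, "caret_model.rds"))
  stale <- model$finalModel
  fresh <- catboost::catboost.load_model(file.path(dir, "final_model.cbm"))
  # Assumption: copy fields that caret or catboost.caret stored on finalModel
  # (e.g. class levels), keeping the fresh model's own fields untouched.
  for (nm in setdiff(names(stale), names(fresh))) {
    fresh[[nm]] <- stale[[nm]]
  }
  model$finalModel <- fresh
  model
}
```

Usage, assuming model and tester from above: call save_caret_catboost(model, "cb_model") once, then in a new session run model4 <- load_caret_catboost("cb_model") followed by predict(model4, tester, type = "prob").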

