
Machine learning with caret: How to specify a timeout?


Is it possible to specify a timeout when training a model in R using train() from the caret library? If not, does an R construct exist that wraps the code and can terminate it after a certain amount of time?


Solution

  • caret's resampling options are configured through the trainControl() function, which has no parameter for a timeout period.

    The two settings in trainControl() that have the most impact on runtime performance are method= and number=. The default method in caret is boot, or bootstrapping. The default number is 25 for bootstrapping, and 10 when method="cv".

    Therefore, a randomForest run with caret's defaults fits the model on 25 bootstrap samples for each candidate tuning value, a very slow process, especially if run on a single processor thread.
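
    The defaults described above can be checked directly by inspecting the list that trainControl() returns. This is a minimal sketch; the variable names default_ctrl and cv_ctrl are illustrative only.

    ```r
    library(caret)

    # trainControl() returns a plain list, so its defaults can be read by name
    default_ctrl <- trainControl()
    cv_ctrl <- trainControl(method = "cv")

    default_ctrl$method   # resampling method used when none is specified
    default_ctrl$number   # number of bootstrap resamples
    cv_ctrl$number        # number of folds when method = "cv"
    ```

    Printing these values before a long run is a quick way to confirm how many resamples train() is about to perform.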

    Forcing a timeout

    R functions can be given a timeout period via the withTimeout() function from the R.utils package.

    For example, we'll run a random forest via caret on the mtcars data set and execute 500 iterations of bootstrap sampling so that train() runs longer than 15 seconds. We will use withTimeout() to stop processing once 15 seconds of CPU or elapsed time have passed, whichever limit is reached first.

    data(mtcars)
    library(randomForest)
    library(R.utils)
    library(caret)
    fitControl <- trainControl(method = "boot",
                               number = 500,
                               allowParallel = FALSE)
    
    withTimeout(
         theModel <- train(mpg ~ .,data=mtcars,method="rf",trControl=fitControl)
         ,timeout=15)
    

    ...and the first part of the output:

    > withTimeout(
    +      theModel <- train(mpg ~ .,data=mtcars,method="rf",trControl=fitControl)
    +      ,timeout=15)
    [2018-05-19 07:32:37] TimeoutException: task 2 failed - "reached elapsed time limit" [cpu=15s, elapsed=15s]
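
    By default the timeout surfaces as an error, which stops a script. The sketch below shows two ways to continue after a timeout instead: the onTimeout argument of withTimeout(), and a tryCatch() handler for the TimeoutException condition. The function slow_task() is a hypothetical stand-in for a long train() call so that the example finishes quickly.

    ```r
    library(R.utils)

    # Hypothetical stand-in for a long-running train() call (about 5 seconds)
    slow_task <- function() {
      for (i in 1:50) Sys.sleep(0.1)
      "finished"
    }

    # Option 1: onTimeout = "silent" makes withTimeout() return NULL on timeout
    silent_result <- withTimeout(slow_task(), timeout = 1, onTimeout = "silent")

    # Option 2: catch the TimeoutException and decide what to return
    caught_result <- tryCatch(
      withTimeout(slow_task(), timeout = 1),
      TimeoutException = function(ex) {
        message("Timed out; returning NULL instead of a result")
        NULL
      }
    )

    is.null(silent_result)
    is.null(caught_result)
    ```

    In both cases a NULL result signals that the work did not finish, so downstream code should test for it before using the model object.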
    

    Improving caret performance

    Aside from simply timing out the caret::train() function, we can use two techniques to improve its performance: parallel processing and adjustments to the trainControl() parameters.

    1. Coding an R script to use parallel processing requires the parallel and doParallel packages, and is a multi-step process.
    2. Changing method="boot" to method="cv" (k-fold cross validation) and reducing number= to 3 or 5 will significantly improve the runtime performance of caret::train().
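
    To see why the second technique helps, it is enough to count model fits. The sketch below assumes train() evaluates 3 candidate tuning values (its usual default for tuneLength) and then refits one final model on the full data. The helper count_model_fits() is hypothetical and is not part of caret.

    ```r
    # Hypothetical helper: resample fits for every candidate tuning value,
    # plus one final fit on the full training data
    count_model_fits <- function(number, n_candidates = 3) {
      number * n_candidates + 1
    }

    boot_fits <- count_model_fits(number = 25)  # default bootstrap settings
    cv_fits <- count_model_fits(number = 5)     # 5-fold cross validation

    boot_fits
    cv_fits
    ```

    Under these assumptions, the 5-fold setup fits roughly one fifth as many models as the bootstrap default.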

    Summarizing techniques I previously described in Improving Performance of Random Forest with caret::train(), the following code uses the Sonar data set to implement parallel processing with caret and randomForest.

    #
    # Sonar example from caret documentation
    #
    
    library(mlbench)
    library(randomForest) # needed for varImpPlot
    data(Sonar)
    #
    # review distribution of Class column
    # 
    table(Sonar$Class)
    library(caret)
    set.seed(95014)
    
    # create training & testing data sets
    
    inTraining <- createDataPartition(Sonar$Class, p = .75, list=FALSE)
    training <- Sonar[inTraining,]
    testing <- Sonar[-inTraining,]
    
    #
    # Step 1: configure parallel processing
    # 
    
    library(parallel)
    library(doParallel)
    cluster <- makeCluster(detectCores() - 1) # convention to leave 1 core for OS 
    registerDoParallel(cluster)
    
    #
    # Step 2: configure trainControl() object for k-fold cross validation with
    #         5 folds
    #
    
    fitControl <- trainControl(method = "cv",
                               number = 5,
                               allowParallel = TRUE)
    
    #
    # Step 3: develop training model
    #
    
    system.time(fit <- train(Class ~ ., method="rf",data=training,trControl = fitControl))
    
    #
    # Step 4: de-register cluster
    #
    stopCluster(cluster)
    registerDoSEQ()
    #
    # Step 5: evaluate model fit 
    #
    fit
    fit$resample
    confusionMatrix(fit)  # dispatches to the method for train objects
    # average OOB error from final model
    mean(fit$finalModel$err.rate[,"OOB"])
    
    plot(fit,main="Accuracy by Predictor Count")
    varImpPlot(fit$finalModel,
               main="Variable Importance Plot: Random Forest")
    sessionInfo()
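
    The evaluation steps above report resampling-based estimates only. As a sketch, the held-out testing rows can provide an independent check. The block below is self-contained: it repeats the partition with the same seed and fits its own model on the training rows, without the parallel setup. The names predictions and heldOut are illustrative.

    ```r
    library(caret)
    library(mlbench)
    library(randomForest)

    data(Sonar)
    set.seed(95014)

    # same 75/25 split as in the listing above
    inTraining <- createDataPartition(Sonar$Class, p = .75, list = FALSE)
    training <- Sonar[inTraining, ]
    testing <- Sonar[-inTraining, ]

    # fit on the training rows only, using 5-fold cross validation
    fitControl <- trainControl(method = "cv", number = 5)
    fit <- train(Class ~ ., data = training, method = "rf", trControl = fitControl)

    # score the rows the model has never seen
    predictions <- predict(fit, newdata = testing)
    heldOut <- confusionMatrix(predictions, testing$Class)
    heldOut$overall["Accuracy"]
    ```

    The held-out accuracy will vary with the seed and the package versions in use, so it is best compared against the cross-validated estimate from the same run.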