Tags: r, machine-learning, cross-validation, r-caret, training-data

In the train method what's the relationship between tuneGrid and trControl?


The preferred way to train standard ML models in R is the caret package and its generic train method. My question: what is the relationship between the tuneGrid and trControl parameters? They are clearly related, but I can't figure out how from the documentation. For example:

library(caret)  
# train and choose best model using cross validation
df <- ... # contains input data
control <- trainControl(method = "cv", number = 10, p = .9, allowParallel = TRUE)
fit <- train(y ~ ., method = "knn", 
             data = df,
             tuneGrid = data.frame(k = seq(9, 71, 2)),
             trControl = control)

If I run the code above, what happens? How are the 10 CV folds (each training on 90% of the data, per the trainControl definition) combined with the 32 levels of k?

More concretely:

  • I have 32 levels for the parameter k.
  • I also have 10 CV folds.

Is the k-nearest-neighbors model trained 32 * 10 times, or something else?


Solution

  • Yes, you are correct. You partition your training data into 10 sets, say 1..10. Starting with set 1, you train the model on sets 2..10 (90% of the training data) and test it on set 1. This is repeated for set 2, set 3, and so on, for a total of 10 fits per candidate value. Since you have 32 values of k to test, that is 32 * 10 = 320 model fits.

    You can also pull out the per-fold CV results by setting returnResamp = "all" in trainControl (it is an argument, not a function). Note that with method = "cv" the held-out fraction is determined by number; the p argument only applies to leave-group-out schemes such as method = "LGOCV". I simplify to 3 folds and 4 values of k below:

    df <- mtcars
    set.seed(100)
    control <- trainControl(method = "cv", number = 3, p = .9, returnResamp = "all")
    fit <- train(mpg  ~ ., method = "knn", 
                 data = mtcars,
                 tuneGrid = data.frame(k = 2:5),
                 trControl = control)
    
    resample_results <- fit$resample
    resample_results
           RMSE  Rsquared      MAE k Resample
    1  3.502321 0.7772086 2.483333 2    Fold1
    2  3.807011 0.7636239 2.861111 3    Fold1
    3  3.592665 0.8035741 2.697917 4    Fold1
    4  3.682105 0.8486331 2.741667 5    Fold1
    5  2.473611 0.8665093 1.995000 2    Fold2
    6  2.673429 0.8128622 2.210000 3    Fold2
    7  2.983224 0.7120910 2.645000 4    Fold2
    8  2.998199 0.7207914 2.608000 5    Fold2
    9  2.094039 0.9620830 1.610000 2    Fold3
    10 2.551035 0.8717981 2.113333 3    Fold3
    11 2.893192 0.8324555 2.482500 4    Fold3
    12 2.806870 0.8700533 2.368333 5    Fold3
    
    # we manually calculate the mean RMSE for each parameter
    tapply(resample_results$RMSE, resample_results$k, mean)
           2        3        4        5 
    2.689990 3.010492 3.156360 3.162392
    
    # and we can see it corresponds to the final fit result
    fit$results
    k     RMSE  Rsquared      MAE    RMSESD RsquaredSD     MAESD
    1 2 2.689990 0.8686003 2.029444 0.7286489 0.09245494 0.4376844
    2 3 3.010492 0.8160947 2.394815 0.6925154 0.05415954 0.4067066
    3 4 3.156360 0.7827069 2.608472 0.3805227 0.06283697 0.1122577
    4 5 3.162392 0.8131593 2.572667 0.4601396 0.08070670 0.1891581
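
    After the grid search finishes, caret picks the value of k with the lowest mean cross-validated RMSE and refits that single model on the full training set. A minimal sketch of how to inspect this, rebuilding the same 3-fold fit as above (bestTune, results, resample, and finalModel are documented components of the object train returns):

    ```r
    library(caret)

    set.seed(100)
    control <- trainControl(method = "cv", number = 3, returnResamp = "all")
    fit <- train(mpg ~ ., method = "knn",
                 data = mtcars,
                 tuneGrid = data.frame(k = 2:5),
                 trControl = control)

    # 3 folds x 4 candidate values of k -> 12 resampled fits
    nrow(fit$resample)   # 12

    # the winning tuning parameter (lowest mean cross-validated RMSE)
    fit$bestTune

    # the same row pulled manually from the summary table
    fit$results[which.min(fit$results$RMSE), ]

    # the model caret returns is refit on ALL of the training data with bestTune
    fit$finalModel
    ```

    So the 320 (or here, 12) fits are only used to score the candidate values of k; none of them is the model you get back.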