The preferred way to train standard ML models in R is the caret package and its generic train function. My question is: what is the relationship between the tuneGrid and trControl parameters? They are clearly related, but I can't figure out how from the documentation. For example:
library(caret)

# train and choose the best model using cross-validation
df <- ...  # contains the input data
control <- trainControl(method = "cv", number = 10, p = .9, allowParallel = TRUE)
fit <- train(y ~ ., method = "knn",
             data = df,
             tuneGrid = data.frame(k = seq(9, 71, 2)),
             trControl = control)
If I run the code above, what exactly happens? How are the 10 CV folds, each trained on 90% of the data as per the trainControl definition, combined with the 32 levels of k?
More concretely: is the k-nearest-neighbors model trained 32 * 10 times, or otherwise?
Yes, you are correct. train partitions your training data into 10 folds, say 1..10. Starting with fold 1, it trains the model on folds 2..10 (90% of the training data) and evaluates it on fold 1. This is then repeated for fold 2, fold 3, and so on, so each candidate value of k is fit 10 times. With 32 values of k to test, that gives 32 * 10 = 320 model fits in total.
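For intuition, here is a rough base-R sketch of the nested loop this setup implies. It is not caret's actual implementation; df, ks, and folds are illustrative stand-ins, and the model-fitting step is elided:

```r
df <- mtcars                # stand-in data set
ks <- seq(9, 71, 2)         # the 32 candidate values of k from tuneGrid
n_folds <- 10
# assign each row to one of 10 CV folds
folds <- cut(seq_len(nrow(df)), breaks = n_folds, labels = FALSE)

fits <- 0
for (k in ks) {             # outer loop: one pass per tuning-grid row
  for (f in seq_len(n_folds)) {   # inner loop: one pass per CV fold
    train_idx <- which(folds != f)  # 9 folds (~90% of rows) for training
    test_idx  <- which(folds == f)  # the held-out fold for evaluation
    # ... fit a k-NN model on df[train_idx, ] and score df[test_idx, ] ...
    fits <- fits + 1
  }
}
fits   # 32 values of k * 10 folds = 320 model fits
```

Every (k, fold) pair produces one fit and one held-out performance estimate, which is exactly what the resample table below records.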
You can also pull out these per-fold CV results by setting the returnResamp argument of trainControl. Below I simplify to 3 folds and 4 values of k:
df <- mtcars
set.seed(100)
control <- trainControl(method = "cv", number = 3, p = .9, returnResamp = "all")
fit <- train(mpg ~ ., method = "knn",
             data = mtcars,
             tuneGrid = data.frame(k = 2:5),
             trControl = control)
resample_results <- fit$resample
resample_results
RMSE Rsquared MAE k Resample
1 3.502321 0.7772086 2.483333 2 Fold1
2 3.807011 0.7636239 2.861111 3 Fold1
3 3.592665 0.8035741 2.697917 4 Fold1
4 3.682105 0.8486331 2.741667 5 Fold1
5 2.473611 0.8665093 1.995000 2 Fold2
6 2.673429 0.8128622 2.210000 3 Fold2
7 2.983224 0.7120910 2.645000 4 Fold2
8 2.998199 0.7207914 2.608000 5 Fold2
9 2.094039 0.9620830 1.610000 2 Fold3
10 2.551035 0.8717981 2.113333 3 Fold3
11 2.893192 0.8324555 2.482500 4 Fold3
12 2.806870 0.8700533 2.368333 5 Fold3
# we manually calculate the mean RMSE across folds for each value of k
tapply(resample_results$RMSE, resample_results$k, mean)
2 3 4 5
2.689990 3.010492 3.156360 3.162392
# and we can see it corresponds to the aggregated result in the final fit
fit$results
k RMSE Rsquared MAE RMSESD RsquaredSD MAESD
1 2 2.689990 0.8686003 2.029444 0.7286489 0.09245494 0.4376844
2 3 3.010492 0.8160947 2.394815 0.6925154 0.05415954 0.4067066
3 4 3.156360 0.7827069 2.608472 0.3805227 0.06283697 0.1122577
4 5 3.162392 0.8131593 2.572667 0.4601396 0.08070670 0.1891581
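After aggregating, caret selects the tuning value with the best mean performance (by default the lowest RMSE for regression) and refits a single final model on the full training set using that value. Continuing the 3-fold example above (this assumes caret is installed; the set.seed(100) call is what makes the selected k reproducible):

```r
library(caret)

set.seed(100)
control <- trainControl(method = "cv", number = 3, p = .9, returnResamp = "all")
fit <- train(mpg ~ ., method = "knn",
             data = mtcars,
             tuneGrid = data.frame(k = 2:5),
             trControl = control)

fit$bestTune    # the k with the lowest mean cross-validated RMSE
fit$finalModel  # one k-NN model refit on all of mtcars with that k
```

So the 320 (or here, 12) cross-validation fits are only used to choose the tuning parameter; the model you actually get back in fit$finalModel is a single fit on the complete data.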