I have the following problem. In a data set from N subjects I have several samples per subject. I want to train a model on the data set, but I would like to make sure that in each resampling iteration the training set contains no replicates (other samples) of the subject being tested.
Alternatively, I would like to block the cross-validation by subject. Is that possible in caret?
Without the caret package, I would do something like this (mock code):
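subjects <- paste0("X", 1:10)
samples <- rep(subjects, each = 5)   # 5 samples per subject, 50 rows in total
x <- matrix(runif(50 * 10), nrow = 50)

# leave one sample out, but drop all other samples of its subject from training
loocv <- function(x, samples) {
  for (i in seq_len(nrow(x))) {
    test <- x[i, , drop = FALSE]
    train <- x[samples != samples[i], , drop = FALSE]
    # create the model from train and predict for test
  }
}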
or, alternatively, leaving out one whole subject at a time:
looSubjCV <- function(x, samples, subjects) {
  for (i in seq_along(subjects)) {
    test <- x[samples == subjects[i], , drop = FALSE]
    train <- x[samples != subjects[i], , drop = FALSE]
    # create the model from train and predict for test
  }
}
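Otherwise, other samples from the same subject would end up in the training set, which leaks information about the test subject into the model and makes the resampled performance estimates overly optimistic.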
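For illustration, here is a minimal runnable version of the leave-one-subject-out loop in base R. It is only a sketch under stated assumptions: the response y is simulated, lm() stands in for whatever model is actually used, and the helper name loso_predict is hypothetical.

```r
# Minimal leave-one-subject-out CV in base R (lm() is only a stand-in model)
set.seed(1)
subjects <- paste0("X", 1:10)
samples  <- rep(subjects, each = 5)      # 5 samples per subject, 50 rows
x <- matrix(runif(50 * 10), nrow = 50)
y <- rnorm(50)                           # hypothetical response
dat <- data.frame(y = y, x)              # predictor columns become X1..X10

# hypothetical helper: returns one out-of-subject prediction per row
loso_predict <- function(dat, samples, subjects) {
  pred <- numeric(nrow(dat))
  for (s in subjects) {
    test_rows <- samples == s
    # the training set contains no rows from subject s
    fit <- lm(y ~ ., data = dat[!test_rows, ])
    pred[test_rows] <- predict(fit, newdata = dat[test_rows, ])
  }
  pred
}

pred <- loso_predict(dat, samples, subjects)
rmse <- sqrt(mean((dat$y - pred)^2))
```

Every prediction in pred comes from a model that never saw any sample of that row's subject, so rmse is an out-of-subject error estimate.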
Not directly, but you can definitely do it using the index and indexOut arguments to trainControl. Here is an example using 10-fold CV:
library(caret)
library(nlme)  # provides the Orthodont data set
subjects <- as.character(unique(Orthodont$Subject))
## figure out folds at the subject level
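## passing seq_along(subjects) (one number per subject) lets createFolds()
## spread the subjects evenly over the k = 10 folds
sub_folds <- createFolds(y = seq_along(subjects), k = 10, list = TRUE, returnTrain = TRUE)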
## now create the mappings to which *rows* are in the training set
## based on which subjects are left in or out
in_train <- holdout <- vector(mode = "list", length = length(sub_folds))
row_index <- 1:nrow(Orthodont)
for(i in seq(along = sub_folds)) {
  ## Which subjects are in the training set for fold i
  sub_in <- subjects[sub_folds[[i]]]
  ## which rows of the data correspond to those subjects
  in_train[[i]] <- row_index[Orthodont$Subject %in% sub_in]
  holdout[[i]] <- row_index[!(Orthodont$Subject %in% sub_in)]
}
names(in_train) <- names(holdout) <- names(sub_folds)
ctrl <- trainControl(method = "cv",
savePredictions = TRUE,
index = in_train,
indexOut = holdout)
mod <- train(distance ~ (age+Sex)^2, data = Orthodont,
method = "lm",
trControl = ctrl)
first_fold <- subset(mod$pred, Resample == "Fold01")
## These were used to fit the model:
unique(Orthodont$Subject[-first_fold$rowIndex])
## These were held out:
unique(Orthodont$Subject[first_fold$rowIndex])
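The subject-to-row mapping in the loop above can also be wrapped in a small reusable helper. This is a base-R sketch, not part of caret: the function name subject_index_lists is hypothetical, it assigns subjects to folds with sample() instead of createFolds() (so the folds will differ from caret's), and the 27 subjects with 4 rows each only mimic the shape of Orthodont.

```r
# Hypothetical helper: build caret-style index/indexOut lists so that all
# rows of a group (subject) land on the same side of every split.
subject_index_lists <- function(groups, k = 10) {
  groups <- as.character(groups)
  ids <- unique(groups)
  k <- min(k, length(ids))                 # never more folds than subjects
  # assign each subject to one of k folds, as evenly as possible
  fold_of <- sample(rep(seq_len(k), length.out = length(ids)))
  names(fold_of) <- ids
  row_fold <- fold_of[groups]              # fold number for every row
  rows <- seq_along(groups)
  holdout  <- lapply(seq_len(k), function(f) rows[row_fold == f])
  in_train <- lapply(seq_len(k), function(f) rows[row_fold != f])
  names(holdout) <- names(in_train) <- sprintf("Fold%02d", seq_len(k))
  list(index = in_train, indexOut = holdout)
}

# usage with a grouping vector shaped like Orthodont (27 subjects x 4 rows)
set.seed(42)
groups <- rep(paste0("S", 1:27), each = 4)
folds <- subject_index_lists(groups, k = 10)

# sanity check: no subject appears on both sides of any split
overlap <- vapply(seq_along(folds$index), function(f)
  length(intersect(groups[folds$index[[f]]], groups[folds$indexOut[[f]]])),
  integer(1))
```

The two lists could then be passed on as trainControl(method = "cv", index = folds$index, indexOut = folds$indexOut). Recent caret releases also include groupKFold(group, k), which returns the training-row indices for a grouping vector directly and can be supplied to index.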