Tags: r, neural-network, cross-validation, r-caret, nnet

Specify Cross Validation Folds with caret


Hello and thanks in advance. I'm using caret to cross-validate a neural network from the nnet package. The method parameter of the trainControl function lets me specify the type of cross-validation, but every option picks the held-out observations at random. Is there any way to make caret cross-validate on specific observations in my data, chosen by an ID or a hard-coded parameter? For example, here's my current code:

library(nnet) 
library(caret) 
library(datasets) 

data(iris) 

train.control <- trainControl( 
    method = "repeatedcv" 
    , number = 4 
    , repeats = 10 
    , verboseIter = T 
    , returnData = T 
    , savePredictions = T 
    ) 

tune.grid <- expand.grid( 
    size = c(2,4,6,8)
    ,decay = 2^(-3:1) 
    ) 

nnet.train <- train( 
    x = iris[,1:4] 
    , y = iris[,5] 
    , method = "nnet" 
    , preProcess = c("center","scale")  
    , metric = "Accuracy" 
    , trControl = train.control 
    , tuneGrid = tune.grid 
    ) 
nnet.train 
plot(nnet.train)

Suppose I wanted to add another column CV_GROUP to the iris data frame and I wanted caret to cross-validate the neural-network on observations with a value of 1 for that column:

iris$CV_GROUP <- c(rep.int(0,times=nrow(iris)-20), rep.int(1,times=20))

Is this possible with caret?


Solution

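  • Use the index and indexOut options of trainControl. index is a list with one vector of training row numbers per resample, and indexOut is a list of the same length with the rows held out for each resample. I coded a way to implement this that lets you select the number of repeats and folds that you want: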

    library(nnet)
    library(caret)
    library(datasets)
    library(data.table)
    library(e1071)
    
    r <- 2 # number of repeats
    k <- 5 # number of folds
    data(iris)
    iris <- data.table(iris)
    
    # Assign every row a random fold label (1..k) once per repeat.
    # Replace this step with your own labels to control the folds by hand.
    set.seed(343)
    for (i in 1:r) {
        newcol <- paste('fold.num', i, sep='')
        iris[, (newcol) := sample(1:k, size=nrow(iris), replace=TRUE)]
    }
    
    # Build one list element per fold per repeat:
    # folds.list holds the training rows, folds.list.out the held-out rows.
    folds.list.out <- list()
    folds.list <- list()
    list.counter <- 1
    for (y in 1:r) {
        newcol <- paste('fold.num', y, sep='')
        for (z in 1:k) {
            folds.list.out[[list.counter]] <- which(iris[[newcol]] == z)
            folds.list[[list.counter]] <- which(iris[[newcol]] != z)
            list.counter <- list.counter + 1
        }
        # Drop the helper column once its folds are recorded
        iris[, (newcol) := NULL]
    }
    
    tune.grid <- expand.grid( 
        size = c(2,4,6,8)
        ,decay = 2^(-3:1) 
        ) 
    
    # index = rows used to fit each resample; indexOut = rows held out to score it
    train.control <- trainControl( 
        index=folds.list
        , indexOut=folds.list.out
        , verboseIter = TRUE 
        , returnData = TRUE 
        , savePredictions = TRUE 
        ) 
    
    iris <- data.frame(iris)
    
    nnet.train <- train( 
        x = iris[,1:4] 
        , y = iris[,5] 
        , method = "nnet" 
        , preProcess = c("center","scale")  
        , metric = "Accuracy" 
        , trControl = train.control 
        , tuneGrid = tune.grid 
        ) 
    
    nnet.train
    plot(nnet.train)
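
For the exact case in the question, the same two options can take a single hand-built split instead of random folds. The following is a minimal sketch, assuming the CV_GROUP column is defined as in the question (0 = training rows, 1 = held-out rows); the resample name Holdout1 and the smaller tuning grid are only illustrative choices. Because the last 20 rows of iris are all one species, the reported Accuracy here reflects a single-class holdout, and caret may warn about missing values in the resampled performance measures.

```r
library(caret)
library(nnet)

data(iris)

# Grouping column from the question: the last 20 rows are flagged for holdout.
iris$CV_GROUP <- c(rep.int(0, times = nrow(iris) - 20), rep.int(1, times = 20))

# One resample only: fit on CV_GROUP == 0, evaluate on CV_GROUP == 1.
# "Holdout1" is an arbitrary, illustrative name for that resample.
folds.list     <- list(Holdout1 = which(iris$CV_GROUP == 0))
folds.list.out <- list(Holdout1 = which(iris$CV_GROUP == 1))

train.control <- trainControl(
    index = folds.list,
    indexOut = folds.list.out,
    savePredictions = TRUE
)

set.seed(343)
nnet.train <- train(
    x = iris[, 1:4],          # predictors only; CV_GROUP is not passed to the model
    y = iris[, 5],
    method = "nnet",
    preProcess = c("center", "scale"),
    metric = "Accuracy",
    trControl = train.control,
    tuneGrid = expand.grid(size = c(2, 4), decay = 2^(-3:1)),
    trace = FALSE             # passed through to nnet to silence its fitting log
)

nnet.train
```

To validate on several fixed groups, add one element per group to both lists, keeping the two lists the same length and in the same order.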