Tags: r, cross-validation, indices, r-caret

Access indices of each CV fold for custom metric function in caret


I want to define a custom metric function in caret, but in this function I want to use additional information that is not used for training. I therefore need the indices (row numbers) of the observations used for validation in each fold.

Here is a silly example:

Generate data:

library(caret)
set.seed(1234)

x <- matrix(rnorm(10), nrow = 5, ncol = 2)
y <- factor(c("y", "n", "y", "y", "n"))

priors <- c(1, 3, 2, 7, 9)

This is my example metric function; it should use information from the priors vector:

my.metric <- function(data,
                      lev = NULL,
                      model = NULL) {
    out <- priors[-->INDICES.OF.DATA<--] + data$pred / data$obs
    names(out) <- "MYMEASURE"
    out
}

myControl <- trainControl(summaryFunction = my.metric, method = "repeatedcv", number = 10, repeats = 2)

fit <- train(y = y, x = x, metric = "MYMEASURE", method = "gbm", trControl = myControl)

To make this perhaps even clearer: I could use this in a survival setting where the priors are survival times in days, and use them in a Surv object to measure survival AUC in the metric function.

How can I do this in caret?


Solution

  • You can access the row numbers using data$rowIndex. Note that the summary function should return a single number as its metric (e.g. ROC, Accuracy, RMSE, ...). The function above seems to return a vector whose length equals the number of observations in the held-out CV data.

    If you're interested in seeing the resamples along with their predictions, you can add print(data) to the my.metric function, as sketched below.
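
    For illustration, here is a minimal sketch of such a debugging summary function (my addition, not from the original answer). Inside it, data is a data frame that typically has the columns obs, pred, rowIndex and, when classProbs = TRUE, one probability column per class level; debug.metric and the "DEBUG" metric name are hypothetical:

    # Sketch of a debugging summary function: prints each held-out
    # resample as caret passes it in, then returns a dummy metric
    debug.metric <- function(data, lev = NULL, model = NULL) {
        print(head(data))   # columns: obs, pred, rowIndex, class probabilities
        out <- 0            # placeholder: the metric must be a single named number
        names(out) <- "DEBUG"
        out
    }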

    Here's an example using your data (enlarged a bit) and Metrics::auc as the performance measure, after multiplying the predicted class probabilities by the priors:

    library(caret)
    library(Metrics)
    
    set.seed(1234)
    x <- matrix(rnorm(100), nrow = 100, ncol = 2)
    set.seed(1234)
    y <- factor(sample(x = c("y", "n"), size = 100, replace = TRUE))
    
    priors <- runif(n = length(y), min = 0.1, max = 0.9)
    
    my.metric <- function(data, lev = NULL, model = NULL) 
    {
        # The performance metric should be a single number.
        # data$y holds the predicted probabilities that the
        # observations in the held-out fold belong to class "y".
        out <- Metrics::auc(actual = as.numeric(data$obs == "y"),
                            predicted = priors[data$rowIndex] * data$y)
        names(out) <- "MYMEASURE"
        out
    }
    
    fitControl <- trainControl(method = "repeatedcv",
                               number = 10,
                               classProbs = TRUE,
                               repeats = 2,
                               summaryFunction = my.metric)
    
    set.seed(1234)
    fit <- train(y = y, 
                 x = x,
                 metric = "MYMEASURE",
                 method="gbm", 
                 verbose = FALSE,
                 trControl = fitControl)
    fit
    
    # Stochastic Gradient Boosting 
    # 
    # 100 samples
    # 2 predictor
    # 2 classes: 'n', 'y' 
    # 
    # No pre-processing
    # Resampling: Cross-Validated (10 fold, repeated 2 times) 
    # 
    # Summary of sample sizes: 90, 90, 90, 90, 90, 89, ... 
    # 
    # Resampling results across tuning parameters:
    #     
    # interaction.depth  n.trees  MYMEASURE  MYMEASURE SD
    # 1                   50      0.5551667  0.2348496   
    # 1                  100      0.5682500  0.2297383   
    # 1                  150      0.5797500  0.2274042   
    # 2                   50      0.5789167  0.2246845   
    # 2                  100      0.5941667  0.2053826   
    # 2                  150      0.5900833  0.2186712   
    # 3                   50      0.5750833  0.2291999   
    # 3                  100      0.5488333  0.2312470   
    # 3                  150      0.5577500  0.2202638   
    # 
    # Tuning parameter 'shrinkage' was held constant at a value of 0.1
    # Tuning parameter 'n.minobsinnode' was held constant at a value of 10
    # MYMEASURE was used to select the optimal model using the largest value.
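
    If you later want the per-resample values of the custom metric, they are kept on the fitted train object; a quick look, using standard caret accessors:

    fit$bestTune         # tuning parameters selected by MYMEASURE
    head(fit$resample)   # MYMEASURE per resample for the selected model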
    

    I don't know too much about survival analysis but I hope this helps.
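
    That said, here is a very rough sketch of the survival idea from the question (my addition, untested): assuming days (follow-up times) and event (0/1 status) are vectors aligned with the rows of x, you could index them with data$rowIndex and use survival::concordance as a stand-in for a survival AUC; surv.metric and the "SURVC" metric name are hypothetical:

    library(survival)

    surv.metric <- function(data, lev = NULL, model = NULL) {
        idx <- data$rowIndex  # rows of the held-out observations
        # Concordance between the predicted probability of class "y" and the
        # survival outcome; reverse = TRUE since higher risk ~ shorter survival
        cc <- survival::concordance(Surv(days[idx], event[idx]) ~ data$y,
                                    reverse = TRUE)
        out <- cc$concordance
        names(out) <- "SURVC"
        out
    }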