Search code examples
rr-caretk-fold

Caret package cross-validation summary in R


Assume i have a K-folds list with K=10, each element contains caret classification performance output:

dput(transformed_conf_matrices$Fold01)
structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 
0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1), dim = c(11L, 
3L), dimnames = list(c("Sensitivity", "Specificity", "Pos Pred Value", 
"Neg Pred Value", "Precision", "Recall", "F1", "Prevalence", 
"Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
)))



transformed_conf_matrices$Fold01
                     Class: setosa Class: versicolor Class: virginica
Sensitivity              1.0000000         1.0000000        1.0000000
Specificity              1.0000000         1.0000000        1.0000000
Pos Pred Value           1.0000000         1.0000000        1.0000000
Neg Pred Value           1.0000000         1.0000000        1.0000000
Precision                1.0000000         1.0000000        1.0000000
Recall                   1.0000000         1.0000000        1.0000000
F1                       1.0000000         1.0000000        1.0000000
Prevalence               0.3333333         0.3333333        0.3333333
Detection Rate           0.3333333         0.3333333        0.3333333
Detection Prevalence     0.3333333         0.3333333        0.3333333
Balanced Accuracy        1.0000000         1.0000000        1.0000000

In this special case , transformed_conf_matrices$Fold01 to transformed_conf_matrices$Fold10 are equal ( same values ).

I would like to have the mean and variance of those metrics. I did many attempts with lapply without success.

The K-folds list :

dput(transformed_conf_matrices)
list(Fold01 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold02 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold03 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold04 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold05 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold06 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold07 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold08 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold09 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))), Fold10 = structure(c(1, 1, 1, 1, 1, 1, 1, 0.333333333333333, 
0.333333333333333, 0.333333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.333333333333333, 0.333333333333333, 0.333333333333333, 1, 1, 
1, 1, 1, 1, 1, 1, 0.333333333333333, 0.333333333333333, 0.333333333333333, 
1), dim = c(11L, 3L), dimnames = list(c("Sensitivity", "Specificity", 
"Pos Pred Value", "Neg Pred Value", "Precision", "Recall", "F1", 
"Prevalence", "Detection Rate", "Detection Prevalence", "Balanced Accuracy"
), c("Class: setosa", "Class: versicolor", "Class: virginica"
))))

Solution

  • I assume the data to be small. Doing things twice doesn't really matter then:

    classes = gsub("Class: ", "", colnames(transformed_conf_matrices[[1L]]))
    transformed_conf_matrices = do.call(cbind, transformed_conf_matrices)
    rowVars = \(x, ...) { rowSums((x - rowMeans(x, ...)) ^ 2L, ...) / (nrow(x) - 1L) }
    

    Then we vapply:

    > vapply(classes, \(i) 
    +        rowMeans(transformed_conf_matrices[, grepl(i, colnames(transformed_conf_matrices))]), 
    +        numeric(length = nrow(transformed_conf_matrices)))
                            setosa versicolor virginica
    Sensitivity          1.0000000  1.0000000 1.0000000
    Specificity          1.0000000  1.0000000 1.0000000
    Pos Pred Value       1.0000000  1.0000000 1.0000000
    Neg Pred Value       1.0000000  1.0000000 1.0000000
    Precision            1.0000000  1.0000000 1.0000000
    Recall               1.0000000  1.0000000 1.0000000
    F1                   1.0000000  1.0000000 1.0000000
    Prevalence           0.3333333  0.3333333 0.3333333
    Detection Rate       0.3333333  0.3333333 0.3333333
    Detection Prevalence 0.3333333  0.3333333 0.3333333
    Balanced Accuracy    1.0000000  1.0000000 1.0000000
    > vapply(classes, \(i) 
    +        rowVars(transformed_conf_matrices[, grepl(i, colnames(transformed_conf_matrices))]), 
    +        numeric(length = nrow(transformed_conf_matrices)))
                         setosa versicolor virginica
    Sensitivity               0          0         0
    Specificity               0          0         0
    Pos Pred Value            0          0         0
    Neg Pred Value            0          0         0
    Precision                 0          0         0
    Recall                    0          0         0
    F1                        0          0         0
    Prevalence                0          0         0
    Detection Rate            0          0         0
    Detection Prevalence      0          0         0
    Balanced Accuracy         0          0         0
    

    rowVars() from here.