
Machine Learning in R - confusion matrix of an ensemble


I'm trying to get the overall accuracy (or the confusionMatrix) of an ensemble across a number of classifiers, but I can't find how to report it.

Already tried:

confusionMatrix(fits_predicts,reference=(mnist_27$test$y))

Error in table(data, reference, dnn = dnn, ...) : all arguments must have the same length

library(caret)
library(dslabs)
set.seed(1)
data("mnist_27")

models <- c("glm", "lda",  "naive_bayes",  "svmLinear", 
            "gamboost",  "gamLoess", "qda", 
            "knn", "kknn", "loclda", "gam",
            "rf", "ranger",  "wsrf", "Rborist", 
            "avNNet", "mlp", "monmlp",
            "adaboost", "gbm",
            "svmRadial", "svmRadialCost", "svmRadialSigma")

fits <- lapply(models, function(model){ 
  print(model)
  train(y ~ ., method = model, data = mnist_27$train)
}) 

names(fits) <- models

fits_predicts <- sapply(fits, function(fit) predict(fit, mnist_27$test))

I'd like to report the confusionMatrix across the different models.


Solution

  • You are not training an ensemble. You are just training a list of separate models without combining their predictions in any way, and that is not an ensemble.

    Given that, the error you get is expected: confusionMatrix expects a single vector of predictions (which is what an ensemble would give you), whereas fits_predicts holds one column of predictions per model. Its length therefore does not match the length of reference, and table() complains that all arguments must have the same length.
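    If what you are after really is an ensemble, the models' predictions first have to be combined into one prediction per observation, and that single vector can then go into confusionMatrix. Below is a minimal sketch assuming simple majority voting (one possible combination rule, not the only one). The objects lv, y and the small fits_predicts data frame are hypothetical toy stand-ins for your real fits_predicts (one factor column per model) and mnist_27$test$y.

```r
library(caret)

# Hypothetical toy stand-ins: replace with your real fits_predicts
# (one factor column per model) and mnist_27$test$y
lv <- c("2", "7")
y <- factor(c("2", "7", "2", "7", "2", "7"), levels = lv)
fits_predicts <- data.frame(
  glm = factor(c("2", "7", "7", "7", "2", "2"), levels = lv),
  lda = factor(c("2", "7", "2", "7", "7", "2"), levels = lv),
  knn = factor(c("2", "2", "2", "7", "2", "7"), levels = lv)
)

# Majority vote: share of models predicting "7" for each observation.
# Ties (votes == 0.5, possible with an even number of models) go to "2" here.
votes <- rowMeans(fits_predicts == "7")
y_hat <- factor(ifelse(votes > 0.5, "7", "2"), levels = lv)

# y_hat is a single prediction vector, so its length matches the reference
ens_cm <- confusionMatrix(y_hat, reference = y)
ens_cm$overall[["Accuracy"]]   # 0.8333333 (5 of the 6 toy cases are correct)
```

    With all 23 of your models the vote cannot tie, since the number of models is odd; with the 4 models used below it can.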

    For simplicity, let's keep only your first 4 models and slightly change your fits_predicts definition so that it returns a data frame with one column per model, i.e.:

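    models <- c("glm", "lda", "naive_bayes", "svmLinear")
    
    # train `fits` exactly as in your question (lapply() + names(fits) <- models), then:
    # lapply() keeps each model's predictions as a factor column
    fits_predicts <- as.data.frame(lapply(fits, predict, newdata = mnist_27$test))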
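    The type of the per-model predictions matters here. As far as I know, recent caret versions stop with an error unless confusionMatrix gets factors. sapply() simplifies a list of factors into a character matrix, and as.data.frame() has not converted strings to factors by default since R 4.0. lapply() followed by as.data.frame() keeps them as factors. A small base-R sketch with two hypothetical prediction vectors (the names preds, m and df are illustrative only) shows the difference, and also why a single reference vector cannot match the sapply() result:

```r
# Hypothetical predictions from two models on the same 3 test cases
preds <- list(
  glm = factor(c("2", "7", "2"), levels = c("2", "7")),
  lda = factor(c("7", "7", "2"), levels = c("2", "7"))
)

m <- sapply(preds, identity)  # simplified into a matrix
class(m[1, 1])                # "character": the factor class is lost
dim(m)                        # 3 2: one column per model, i.e. 6 values,
                              # while a reference would only have 3

df <- as.data.frame(preds)    # a list of factors stays a list of factors
sapply(df, is.factor)         # glm TRUE, lda TRUE
```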
    

    Here is how you can get the confusion matrix for each of your models:

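    # factor() makes sure each model's predictions are a factor with the reference's levels
    cm <- lapply(fits_predicts, function(pred) {
      confusionMatrix(factor(pred, levels = levels(mnist_27$test$y)), reference = mnist_27$test$y)
    })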
    

    which gives

    > cm
    $glm
    Confusion Matrix and Statistics
    
              Reference
    Prediction  2  7
             2 82 26
             7 24 68
    
                   Accuracy : 0.75           
                     95% CI : (0.684, 0.8084)
        No Information Rate : 0.53           
        P-Value [Acc > NIR] : 1.266e-10      
    
                      Kappa : 0.4976         
     Mcnemar's Test P-Value : 0.8875         
    
                Sensitivity : 0.7736         
                Specificity : 0.7234         
             Pos Pred Value : 0.7593         
             Neg Pred Value : 0.7391         
                 Prevalence : 0.5300         
             Detection Rate : 0.4100         
       Detection Prevalence : 0.5400         
          Balanced Accuracy : 0.7485         
    
           'Positive' Class : 2              
    
    
    $lda
    Confusion Matrix and Statistics
    
              Reference
    Prediction  2  7
             2 82 26
             7 24 68
    
                   Accuracy : 0.75           
                     95% CI : (0.684, 0.8084)
        No Information Rate : 0.53           
        P-Value [Acc > NIR] : 1.266e-10      
    
                      Kappa : 0.4976         
     Mcnemar's Test P-Value : 0.8875         
    
                Sensitivity : 0.7736         
                Specificity : 0.7234         
             Pos Pred Value : 0.7593         
             Neg Pred Value : 0.7391         
                 Prevalence : 0.5300         
             Detection Rate : 0.4100         
       Detection Prevalence : 0.5400         
          Balanced Accuracy : 0.7485         
    
           'Positive' Class : 2              
    
    
    $naive_bayes
    Confusion Matrix and Statistics
    
              Reference
    Prediction  2  7
             2 88 23
             7 18 71
    
                   Accuracy : 0.795           
                     95% CI : (0.7323, 0.8487)
        No Information Rate : 0.53            
        P-Value [Acc > NIR] : 5.821e-15       
    
                      Kappa : 0.5873          
     Mcnemar's Test P-Value : 0.5322          
    
                Sensitivity : 0.8302          
                Specificity : 0.7553          
             Pos Pred Value : 0.7928          
             Neg Pred Value : 0.7978          
                 Prevalence : 0.5300          
             Detection Rate : 0.4400          
       Detection Prevalence : 0.5550          
          Balanced Accuracy : 0.7928          
    
           'Positive' Class : 2               
    
    
    $svmLinear
    Confusion Matrix and Statistics
    
              Reference
    Prediction  2  7
             2 81 24
             7 25 70
    
                   Accuracy : 0.755           
                     95% CI : (0.6894, 0.8129)
        No Information Rate : 0.53            
        P-Value [Acc > NIR] : 4.656e-11       
    
                      Kappa : 0.5085          
     Mcnemar's Test P-Value : 1               
    
                Sensitivity : 0.7642          
                Specificity : 0.7447          
             Pos Pred Value : 0.7714          
             Neg Pred Value : 0.7368          
                 Prevalence : 0.5300          
             Detection Rate : 0.4050          
       Detection Prevalence : 0.5250          
          Balanced Accuracy : 0.7544          
    
           'Positive' Class : 2       
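    Since the question is mainly about overall accuracy, it can be handier to pull just that number out of each confusionMatrix object, which stores it in its $overall component. A minimal sketch, again using hypothetical toy objects (lv, y and a two-column fits_predicts) instead of your real ones; the sapply(..., mean(p == y)) line is a base-R alternative that gives the same accuracies without caret:

```r
library(caret)

# Hypothetical toy stand-ins for fits_predicts and mnist_27$test$y
lv <- c("2", "7")
y <- factor(c("2", "7", "2", "7", "2", "7"), levels = lv)
fits_predicts <- data.frame(
  glm = factor(c("2", "7", "7", "7", "2", "2"), levels = lv),
  lda = factor(c("2", "7", "2", "7", "2", "2"), levels = lv)
)

# One confusionMatrix object per model (column)
cm <- lapply(fits_predicts, confusionMatrix, reference = y)

# Overall accuracy of each model, taken from the $overall component
acc <- sapply(cm, function(x) x$overall[["Accuracy"]])
acc        # glm: 0.6666667, lda: 0.8333333 on this toy data

# Same numbers in base R: share of correct predictions per model
acc_base <- sapply(fits_predicts, function(p) mean(p == y))

# Average accuracy across the models
mean(acc)  # 0.75
```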
    

    You can also access the confusion matrix of an individual model, e.g. for lda:

    > cm['lda']
    $lda
    Confusion Matrix and Statistics
    
              Reference
    Prediction  2  7
             2 82 26
             7 24 68
    
                   Accuracy : 0.75           
                     95% CI : (0.684, 0.8084)
        No Information Rate : 0.53           
        P-Value [Acc > NIR] : 1.266e-10      
    
                      Kappa : 0.4976         
     Mcnemar's Test P-Value : 0.8875         
    
                Sensitivity : 0.7736         
                Specificity : 0.7234         
             Pos Pred Value : 0.7593         
             Neg Pred Value : 0.7391         
                 Prevalence : 0.5300         
             Detection Rate : 0.4100         
       Detection Prevalence : 0.5400         
          Balanced Accuracy : 0.7485         
    
           'Positive' Class : 2
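    Note that cm['lda'] (single brackets) returns a one-element list, which is why the output above is headed $lda. cm[["lda"]] (double brackets) returns the confusionMatrix object itself, whose components ($table, $overall, $byClass) can then be used directly. A minimal sketch with hypothetical toy objects (lv, y, fits_predicts and lda_cm are illustrative names, not from the original):

```r
library(caret)

# Hypothetical toy stand-ins for fits_predicts and mnist_27$test$y
lv <- c("2", "7")
y <- factor(c("2", "7", "2", "7", "2", "7"), levels = lv)
fits_predicts <- data.frame(
  glm = factor(c("2", "7", "7", "7", "2", "2"), levels = lv),
  lda = factor(c("2", "7", "2", "7", "2", "2"), levels = lv)
)
cm <- lapply(fits_predicts, confusionMatrix, reference = y)

class(cm["lda"])    # "list": a sub-list holding one element
class(cm[["lda"]])  # "confusionMatrix": the object itself

lda_cm <- cm[["lda"]]
lda_cm$table                     # 2 x 2 counts, Prediction x Reference
lda_cm$overall[["Accuracy"]]     # overall accuracy
lda_cm$byClass[["Sensitivity"]]  # statistic for the positive class ("2")
```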