
How to produce a confusion matrix and find the misclassification rate of the Naïve Bayes Classifier?


Using the iris dataset in R, I'm trying to fit a Naïve Bayes classifier to the iris training data so that I can produce a confusion matrix of the training data set (predicted vs. actual) for the Naïve Bayes classifier. What is the misclassification rate of the Naïve Bayes classifier?

Here's my code so far:

 library(caTools)  # for sample.split()
 library(e1071)    # for naiveBayes()

 iris$spl = sample.split(iris, SplitRatio = 0.8)
 train = subset(iris, iris$spl == TRUE)
 test = subset(iris, iris$spl == FALSE)

 iris.nb <- naiveBayes(Species ~ ., data = train)
 iris.nb

 nb_test_predict <- predict(iris.nb, train)

Any suggestions on how to approach this problem?


Solution

  • The caret package includes a confusionMatrix function that returns a very complete output.

    library(e1071)
    library(caTools)
    library(caret)
    
    iris$spl = sample.split(iris, SplitRatio = 0.8)
    train <- subset(iris, iris$spl == TRUE)
    test <- subset(iris, iris$spl == FALSE)
    
    # Fit the Naive Bayes model on the training set
    iris.nb <- naiveBayes(Species ~ ., data = train)
    
    # Predict on the held-out test set, dropping the true labels
    nb_train_predict <- predict(iris.nb, test[ , names(test) != "Species"])
    
    # Confusion matrix of predicted vs. actual species
    cfm <- confusionMatrix(nb_train_predict, test$Species)
    cfm
    
    # Confusion Matrix and Statistics
    # 
    # Reference
    # Prediction   setosa versicolor virginica
    # setosa         17          0         0
    # versicolor      0         14         1
    # virginica       0          2        16
    # 
    # Overall Statistics
    # 
    # Accuracy : 0.94            
    # 95% CI : (0.8345, 0.9875)
    # No Information Rate : 0.34            
    # P-Value [Acc > NIR] : < 2.2e-16       
    # 
    # Kappa : 0.9099          
    # Mcnemar's Test P-Value : NA              
    # 
    # Statistics by Class:
    # 
    #                      Class: setosa Class: versicolor Class: virginica
    # Sensitivity                   1.00            0.8750           0.9412
    # Specificity                   1.00            0.9706           0.9394
    # Pos Pred Value                1.00            0.9333           0.8889
    # Neg Pred Value                1.00            0.9429           0.9688
    # Prevalence                    0.34            0.3200           0.3400
    # Detection Rate                0.34            0.2800           0.3200
    # Detection Prevalence          0.34            0.3000           0.3600
    # Balanced Accuracy             1.00            0.9228           0.9403
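
    The question also asks for the misclassification rate. That is just 1 minus the
    accuracy reported above (0.94 here), and it can be read straight off the
    confusionMatrix object. A minimal sketch using the cfm object created above:
    
    # Misclassification rate = 1 - accuracy (caret stores accuracy in cfm$overall)
    misclass_rate <- 1 - as.numeric(cfm$overall["Accuracy"])
    misclass_rate
    # [1] 0.06
    
    # Equivalently, compute it directly from the confusion matrix table
    1 - sum(diag(cfm$table)) / sum(cfm$table)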
    

    To display the confusion matrix as a ggplot graphic:

    library(ggplot2)
    library(scales)
    
    ggplotConfusionMatrix <- function(m){
      # Title shows overall accuracy and kappa taken from the caret object
      mytitle <- paste("Accuracy", percent_format()(m$overall[1]),
                       "Kappa", percent_format()(m$overall[2]))
      p <-
        ggplot(data = as.data.frame(m$table),
               aes(x = Reference, y = Prediction)) +
        # Fill each tile by log(Freq) so large diagonal counts don't wash out the rest
        geom_tile(aes(fill = log(Freq)), colour = "white") +
        scale_fill_gradient(low = "white", high = "steelblue") +
        # Print the raw count on each tile
        geom_text(aes(x = Reference, y = Prediction, label = Freq)) +
        theme(legend.position = "none") +
        ggtitle(mytitle)
      return(p)
    }
    
    ggplotConfusionMatrix(cfm)
    

    (Resulting plot: the confusion matrix drawn as a ggplot tile chart, with accuracy and kappa shown in the title.)
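
    One more note: the question asks for the confusion matrix of the training data
    (predicted vs. actual). The same calls work there; just predict on train instead
    of test. A minimal sketch (the variable name nb_fit_predict is my own):
    
    # Confusion matrix on the training set itself (predicted vs. actual)
    nb_fit_predict <- predict(iris.nb, train[ , names(train) != "Species"])
    confusionMatrix(nb_fit_predict, train$Species)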