
R - Generate a confusion matrix and ROC curve for a model created by the multinomial_naive_bayes() function


I have a data set with many factor/categorical/nominal columns/variables/features, and I need to build a multinomial naive Bayes classifier for it. I tried the caret package, but I don't think it was fitting a multinomial naive Bayes model; I think it was doing Gaussian naive Bayes (details in a separate question). I have now discovered multinomial_naive_bayes(), which seems perfect: it appears to handle NA values in the predictor variables, and a variable with only one value, without complaining.

The issue is that I can't figure out how to do my post-processing/analysis of the model generated by multinomial_naive_bayes(). I want a caret-style confusionMatrix for the model, and also for the prediction output versus the test data, so I can assess the classifier. I would also like to generate an ROC curve. How can I do this?

I have included the example from the documentation of multinomial_naive_bayes() below. How would I update this code to get my confusion matrices and ROC curve?

From: R Package 'naivebayes', section: multinomial_naive_bayes pg 10

library(naivebayes)

### Simulate the data:
cols <- 10 ; rows <- 100
M <- matrix(sample(0:5, rows * cols, TRUE, prob = c(0.95, rep(0.01, 5))), nrow = rows, ncol = cols)
y <- factor(sample(paste0("class", LETTERS[1:2]), rows, TRUE, prob = c(0.3,0.7)))
colnames(M) <- paste0("V", seq_len(ncol(M)))
laplace <- 1

### Train the Multinomial Naive Bayes
mnb <- multinomial_naive_bayes(x = M, y = y, laplace = laplace)
summary(mnb)
    
# Classification
head(predict(mnb, newdata = M, type = "class")) # head(mnb %class% M)

# Posterior probabilities
head(predict(mnb, newdata = M, type = "prob")) # head(mnb %prob% M)

# Parameter estimates
coef(mnb)
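The question also asks about prediction output versus test data, while the documentation example predicts on the same rows it was trained on. A minimal sketch of a hold-out evaluation, assuming a hypothetical 70/30 random split and a fixed seed (`set.seed(42)`). The names `train_idx`, `M_train`, `M_test`, `y_train`, `y_test`, `pred_test` and `tab_test` are illustrative only and are not part of the naivebayes API. It uses base R `table()` for the cross-tabulation, so it runs without caret.

```r
library(naivebayes)

set.seed(42)  # assumption: fixed seed so the split is reproducible

### Simulate the data exactly as in the documentation example
cols <- 10 ; rows <- 100
M <- matrix(sample(0:5, rows * cols, TRUE, prob = c(0.95, rep(0.01, 5))),
            nrow = rows, ncol = cols)
y <- factor(sample(paste0("class", LETTERS[1:2]), rows, TRUE, prob = c(0.3, 0.7)))
colnames(M) <- paste0("V", seq_len(ncol(M)))

### Hypothetical 70/30 split of the row indices
train_idx <- sample(seq_len(rows), size = round(0.7 * rows))
M_train <- M[train_idx, , drop = FALSE]
M_test  <- M[-train_idx, , drop = FALSE]
y_train <- y[train_idx]   # subsetting keeps both factor levels
y_test  <- y[-train_idx]

### Fit on the training rows only
mnb <- multinomial_naive_bayes(x = M_train, y = y_train, laplace = 1)

### Predicted classes for the held-out rows
pred_test <- predict(mnb, newdata = M_test, type = "class")

### Rows = predicted class, columns = actual class
tab_test <- table(predicted = pred_test, actual = y_test)
print(tab_test)

### Proportion of held-out rows classified correctly
mean(pred_test == y_test)
```

Because `pred_test` and `y_test` share the same factor levels, the same two objects can be handed straight to caret's `confusionMatrix()` instead of `table()`.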

Solution

  • You can use the caret function confusionMatrix:

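    library(caret)
    # Predicted classes for the rows of M
    pred <- predict(mnb, newdata = M, type = "class")
    # Rows of the table are predictions, columns are the true labels
    confusionMatrix(table(pred, y))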
    
    Confusion Matrix and Statistics
    
            y
    pred     classA classB
      classA     10      3
      classB     20     67
    

    Or, since pred and y are factors with the same levels, you can pass them directly:

    confusionMatrix(pred,y)
    

    For an ROC curve, you need to provide the predicted probability of one class rather than the predicted label:

    library(pROC)
    # Column 2 of the probability matrix is P(classB), the second factor level of y
    roc_ <- roc(y, predict(mnb, newdata = M, type = "prob")[, 2])

    plot(roc_)
    

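Extending the answer to a held-out test set, the sketch below passes `positive = "classB"` to caret's `confusionMatrix()` so that sensitivity and specificity refer to classB, reads individual statistics from the returned object's `overall` and `byClass` components, and computes the area under the ROC curve with pROC's `auc()`. The 70/30 split, the seed and the choice of classB as the positive class are assumptions for illustration. Setting `levels` and `direction` in `roc()` only makes explicit which class is treated as the case, which also silences pROC's automatic-detection messages.

```r
library(naivebayes)
library(caret)
library(pROC)

set.seed(1)  # assumption: fixed seed for a reproducible split

### Simulated data, as in the documentation example
cols <- 10 ; rows <- 100
M <- matrix(sample(0:5, rows * cols, TRUE, prob = c(0.95, rep(0.01, 5))),
            nrow = rows, ncol = cols)
y <- factor(sample(paste0("class", LETTERS[1:2]), rows, TRUE, prob = c(0.3, 0.7)))
colnames(M) <- paste0("V", seq_len(ncol(M)))

### Hypothetical 70/30 train/test split
train_idx <- sample(seq_len(rows), size = round(0.7 * rows))
y_test <- y[-train_idx]
mnb <- multinomial_naive_bayes(x = M[train_idx, , drop = FALSE],
                               y = y[train_idx], laplace = 1)

### Class predictions and class probabilities for the test rows
M_test    <- M[-train_idx, , drop = FALSE]
pred_test <- predict(mnb, newdata = M_test, type = "class")
prob_test <- predict(mnb, newdata = M_test, type = "prob")

### Confusion matrix with classB treated as the positive class
cm <- confusionMatrix(data = pred_test, reference = y_test, positive = "classB")
print(cm)
cm$overall["Accuracy"]
cm$byClass[c("Sensitivity", "Specificity")]

### ROC on the test rows: classA = controls, classB = cases,
### and higher P(classB) should indicate a case
roc_test <- roc(response = y_test, predictor = prob_test[, "classB"],
                levels = c("classA", "classB"), direction = "<")
auc(roc_test)
```

Training-set and test-set results usually differ, and on this purely random simulated data the test-set AUC will tend to hover around 0.5; on real data, the test-set figures are the ones that indicate how the classifier generalizes.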