
How can I sum votes by category from the randomForest predict function in R?


This example code creates a data frame whose first column holds the majority vote of the 10 trees; the next 10 columns hold each tree's individual class vote. I want to create a chart showing the distribution of votes for each row. What is the best way to do that?

library(tidyverse)
library(caret)
library(randomForest)

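# Make cyl a factor so the forest does classification; with factor(cyl) ~ . the dot would also add cyl as a predictor
mtcars$cyl <- factor(mtcars$cyl)
train_index_cars <- as.vector(createDataPartition(mtcars[['cyl']], p = .8, list = FALSE, times = 1))
mytrain <- mtcars[train_index_cars, ]
mytest <- mtcars[-train_index_cars, ]

# predict.all is an argument of predict(), not of randomForest()
car_forest <- randomForest(cyl ~ ., data = mytrain, ntree = 10)
cartest_predicted <- as.data.frame(predict(car_forest, newdata = mytest, predict.all = TRUE))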

The Merc 280 row of cartest_predicted looks something like this (the last 6 trees are omitted):

         aggregate individual.1 individual.2 individual.3 individual.4
Merc 280         6            6            8            6            4

I'd like to add three columns to each row containing the number of votes for each category (4, 6, 8) across the trees. I'm envisioning output like this:

individual.10 Votes_4 Votes_6 Votes_8
            6       2       7       1

What is the best way to count, row by row, how many of the vote columns match each category? I haven't been able to find exactly what I need. Does this output already exist as part of the randomForest package and I'm just overlooking it?
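For a single row, the counting described above can be sketched in base R with a vectorized comparison plus rowSums(). The data frame here is a hypothetical stand-in for cartest_predicted: trees 1 to 4 copy the Merc 280 row, while the votes of trees 5 to 10 are invented so the totals match the desired output. The names tree_votes, votes_df and vote_cols are illustrative only.

```r
# Hypothetical one-row stand-in for cartest_predicted. Trees 1-4 copy the
# Merc 280 row; trees 5-10 are invented so the totals match the desired
# output (2 fours, 7 sixes, 1 eight).
tree_votes <- c("6", "8", "6", "4", "6", "6", "4", "6", "6", "6")
votes_df <- as.data.frame(t(tree_votes), stringsAsFactors = FALSE)
names(votes_df) <- paste0("individual.", seq_along(tree_votes))
votes_df <- cbind(aggregate = factor("6", levels = c("4", "6", "8")), votes_df)
rownames(votes_df) <- "Merc 280"

# Positions of the per-tree vote columns (individual.1 ... individual.10)
vote_cols <- grep("^individual\\.", names(votes_df))

# levels() lists every class, so a class that never wins still gets a column.
for (cl in levels(votes_df$aggregate)) {
  # Comparing the vote columns with cl gives a logical matrix;
  # rowSums() counts the TRUEs, i.e. the trees that voted for cl.
  votes_df[[paste("Votes", cl, sep = "_")]] <- rowSums(votes_df[, vote_cols] == cl)
}

votes_df[, c("individual.10", "Votes_4", "Votes_6", "Votes_8")]
# Votes_4 = 2, Votes_6 = 7, Votes_8 = 1
```

Looping over levels() rather than unique(aggregate) matters here: in this one-row example only "6" wins the majority, yet the 4 and 8 votes still get their own columns.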


Solution

  • This should work:

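    # Take the per-tree votes (individual.1 ... individual.10) once, before any
    # Votes_* column is added, so the new columns are never counted as votes.
    votes = cartest_predicted[, grep("^individual", names(cartest_predicted))]
    
    # Defining a temporary function, to be passed within apply().
    temp.fun = function(x) sum(x == i)
    
    # Iterate over all classes (4, 6, 8), not only those that win a majority vote.
    for (i in levels(cartest_predicted$aggregate))
    {
      cartest_predicted$temp = apply(votes, MARGIN = 1, temp.fun) # Requested results.
      colnames(cartest_predicted)[ncol(cartest_predicted)] = paste("Votes", i, sep = "_") # Renaming new column.
    }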
    

    The for loop iterates over every class the trees can vote for. For each class i, a temporary function counts how many of a row's votes are equal to i, and apply() calls that function on each row of the vote columns of cartest_predicted (notice MARGIN = 1). Finally, paste() builds the name of each new column (Votes_4, Votes_6, Votes_8).
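Regarding whether randomForest already provides these counts: predict() on a randomForest classifier accepts type = "vote", and with norm.votes = FALSE it returns raw vote counts (one column per class) instead of fractions. This is a sketch that assumes a reasonably recent randomForest version; the seed, the base-R 6-row hold-out (standing in for the caret partition in the question) and the names cars, test_rows, vote_counts and vote_df are illustrative choices.

```r
library(randomForest)

set.seed(42)                        # arbitrary, for reproducibility only
cars <- mtcars
cars$cyl <- factor(cars$cyl)        # classification target with levels 4, 6, 8
test_rows <- sample(nrow(cars), 6)  # simple hold-out instead of caret's partition
car_forest <- randomForest(cyl ~ ., data = cars[-test_rows, ], ntree = 10)

# One row per test car, one column per class; entries are numbers of trees.
vote_counts <- predict(car_forest, newdata = cars[test_rows, ],
                       type = "vote", norm.votes = FALSE)

# Reshape into the Votes_4 / Votes_6 / Votes_8 layout from the question.
vote_df <- as.data.frame(unclass(vote_counts))
names(vote_df) <- paste("Votes", colnames(vote_counts), sep = "_")
print(vote_df)
```

Each row of vote_counts should sum to ntree (10 here), and the counts should agree with tallying the individual columns returned by predict.all = TRUE, so vote_df can be cbind()-ed onto cartest_predicted when the same forest and newdata are used.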