Tags: r, random-forest, decision-tree, caret

predict() random forests - extract predictions from individual trees


I built a random forest model called iris_class *.

library(randomForest)

set.seed(10)
# Assign each row to group 1 (~70%, training) or group 2 (~30%, test)
index_row <- sample(2, 
                    nrow(iris), 
                    replace = TRUE, 
                    prob = c(0.7, 0.3)
)  

train_data <- iris[index_row == 1, ]
test_data <- iris[index_row == 2, ]

iris_class <- randomForest(Species ~ ., 
                           data = train_data)

This is what iris_class looks like:

> iris_class

Call:
 randomForest(formula = Species ~ ., data = train_data) 
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 2

        OOB estimate of  error rate: 4.5%
Confusion matrix:
           setosa versicolor virginica class.error
setosa         38          0         0  0.00000000
versicolor      0         39         2  0.04878049
virginica       0          3        29  0.09375000

I then use predict() to make predictions on the test set:

predictions <- predict(iris_class, test_data[, -5], type = "response")

iris_class is made of 500 individual trees. If I understand correctly, when I run predict() with iris_class, each of the 500 already-grown trees classifies every observation, and I am shown the combined result of those 500 trees (for classification, the majority vote).
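For classification, the combined result is a majority vote across the trees rather than an average. A minimal sketch of that aggregation, assuming the randomForest package and the question's own train/test split: it keeps every tree's prediction with predict.all = TRUE, counts the votes per class, and checks the majority class against the forest's own prediction. The names vote_counts, majority, top_unique and agree are illustrative, not part of the package.

```r
library(randomForest)

# Same setup as in the question
set.seed(10)
index_row <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
train_data <- iris[index_row == 1, ]
test_data  <- iris[index_row == 2, ]
iris_class <- randomForest(Species ~ ., data = train_data)

# $individual: one row per test observation, one column per tree
# $aggregate:  the usual forest prediction
pred_all <- predict(iris_class, test_data[, -5], predict.all = TRUE)

# Count the votes each class received, row by row
classes <- levels(train_data$Species)
vote_counts <- t(apply(pred_all$individual, 1, function(v)
  table(factor(v, levels = classes))))
colnames(vote_counts) <- classes

# Class with the most votes in each row
majority <- classes[max.col(vote_counts, ties.method = "first")]

# Skip rows with a tied top vote: the forest may resolve ties differently
top_unique <- apply(vote_counts, 1, function(x) sum(x == max(x)) == 1)
agree <- all(majority[top_unique] ==
             as.character(pred_all$aggregate)[top_unique])
print(agree)
```

Rows with a tied top vote are left out of the comparison because the forest's tie-breaking need not match max.col().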

My question is:

Is there a way to extract the prediction of each of the 500 trees?

In other words, can predict() return an object that, for each item being classified, has 500 rows saying setosa, versicolor or virginica? Or a summarised version of such an object (shown below)? The purpose: I want to know how "confident" the model actually is. When it predicts a plant to be setosa, did 450 trees say setosa and 50 say something else, or was it 251 vs 249?
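The 450-vs-50 versus 251-vs-249 distinction comes down to the share of the winning class and its lead over the runner-up. A small sketch in base R; vote_margin() is a hypothetical helper that takes one observation's per-tree predictions as a character vector, and the two vectors below simply encode the vote splits from the question.

```r
# Hypothetical helper: summarise one observation's per-tree predictions
# as the winning class, its vote share, and its lead over the runner-up
vote_margin <- function(tree_preds) {
  shares <- sort(table(tree_preds) / length(tree_preds), decreasing = TRUE)
  runner_up <- if (length(shares) > 1) shares[[2]] else 0
  list(class  = names(shares)[1],
       share  = shares[[1]],
       margin = shares[[1]] - runner_up)
}

# The two scenarios from the question, each with 500 trees
confident  <- c(rep("setosa", 450), rep("versicolor", 50))
close_call <- c(rep("setosa", 251), rep("versicolor", 249))

vote_margin(confident)   # share 0.9, margin 0.8
vote_margin(close_call)  # share 0.502, margin 0.004
```

Both cases predict setosa, but only the margin tells them apart.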

What line of code will extract predictions for individual trees?

My ideal output would look something like this:

> predictions_info
       setosa versicolor  virginica       pred
1  0.01517536 0.55449239 0.43033225 versicolor
2  0.21957988 0.71962024 0.06079987 versicolor
3  0.28146250 0.36777757 0.35075993 versicolor
4  0.51503150 0.41750308 0.06746543     setosa
5  0.25832598 0.10796878 0.63370523  virginica
6  0.24603616 0.07558151 0.67838233  virginica
7  0.02323489 0.41547464 0.56129047  virginica
8  0.41155830 0.49214444 0.09629726 versicolor
9  0.30217529 0.39852784 0.29929686 versicolor
10 0.45923782 0.49147493 0.04928725 versicolor
11 0.70479092 0.27648912 0.01871996     setosa
12 0.34489442 0.02606726 0.62903832  virginica
13 0.15553471 0.18903000 0.65543530  virginica
[...]

Here the pred column is what predict() currently returns, and the first three columns show what proportion of the 500 trees gave each prediction. (These numbers and predictions are made up and don't match the model output.)

*This example is originally from this website (modified by me): https://rpubs.com/Jay2548/519589


Solution

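  • Use predict.all = TRUE. predict() then returns a list whose $aggregate component is the usual forest prediction and whose $individual component holds every tree's prediction, so you can compute whatever you want from it. Be careful: $individual is a matrix of size nrow(dataset) x number of trees.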
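A sketch of the predict.all = TRUE approach that builds a table shaped like the predictions_info example in the question, assuming the randomForest package and the question's train/test split. Each proportion is the number of trees voting for a class divided by the number of trees; props and classes are illustrative names.

```r
library(randomForest)

# Same setup as in the question
set.seed(10)
index_row <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
train_data <- iris[index_row == 1, ]
test_data  <- iris[index_row == 2, ]
iris_class <- randomForest(Species ~ ., data = train_data)

# Keep every tree's prediction alongside the aggregate one
pred_all <- predict(iris_class, test_data[, -5], predict.all = TRUE)

# Proportion of trees voting for each class, one row per test observation
classes <- levels(train_data$Species)
props <- t(apply(pred_all$individual, 1, function(v)
  table(factor(v, levels = classes)) / length(v)))
colnames(props) <- classes

# Proportions plus the forest's usual prediction
predictions_info <- data.frame(props, pred = pred_all$aggregate)
head(predictions_info)
```

If only the proportions are needed, predict(iris_class, test_data[, -5], type = "prob") should give the same per-class vote fractions without keeping the full per-tree matrix.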