I built a random forest model called `iris_class`*:
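```r
library(randomForest)

set.seed(10)
# Randomly assign ~70% of rows to training (1) and ~30% to testing (2)
index_row <- sample(2,
                    nrow(iris),
                    replace = TRUE,
                    prob = c(0.7, 0.3))
train_data <- iris[index_row == 1, ]
test_data <- iris[index_row == 2, ]
iris_class <- randomForest(Species ~ .,
                           data = train_data)
```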
This is what `iris_class` looks like:

```
> iris_class

Call:
 randomForest(formula = Species ~ ., data = train_data)
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 2

        OOB estimate of  error rate: 4.5%
Confusion matrix:
           setosa versicolor virginica class.error
setosa         38          0         0  0.00000000
versicolor      0         39         2  0.04878049
virginica       0          3        29  0.09375000
```
I then use it to make predictions with the `predict()` function:

```r
predictions <- predict(iris_class, test_data[, -5], type = "response")
```
`iris_class` is made of 500 individual trees. If I understand correctly, when I run `predict()` using `iris_class`, each of the 500 trees gives a classification, and I am shown the combined result of those 500 trees.
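For a classification forest, the trees are grown once, when `randomForest()` is called; `predict()` then passes each new row down every existing tree and combines the trees' class labels by majority vote rather than by averaging. The toy sketch below illustrates only that aggregation step in base R. The 3 x 5 matrix of tree votes is made up, and `majority_vote()` is a hypothetical helper, not part of randomForest.

```r
# Made-up per-tree predictions: 3 flowers (rows) x 5 trees (columns).
tree_votes <- matrix(
  c("setosa",     "setosa",    "setosa",     "versicolor", "setosa",
    "versicolor", "virginica", "versicolor", "versicolor", "virginica",
    "virginica",  "virginica", "versicolor", "virginica",  "virginica"),
  nrow = 3, byrow = TRUE
)
classes <- c("setosa", "versicolor", "virginica")

# Hypothetical helper: count the votes for each class in every row and
# keep the most frequent class. which.max() takes the first class on a
# tie, whereas randomForest breaks ties at random.
majority_vote <- function(votes, levels) {
  apply(votes, 1, function(row) {
    counts <- table(factor(row, levels = levels))
    names(counts)[which.max(counts)]
  })
}

majority_vote(tree_votes, classes)
# returns c("setosa", "versicolor", "virginica")
```

The randomForest documentation recommends an odd `ntree` if random tie-breaking is undesirable.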
My question is:
is there a way to extract the prediction of each of the 500 trees?
In other words, can the `predict()` function return an object that, for each item being classified, has 500 rows saying `setosa`, `versicolor` or `virginica`? Or a summarised version of such an object (shown below)? The purpose is that I want to know how "confident" the model actually is. When it predicts a plant to be `setosa`, did 450 trees say `setosa` and 50 say something else, or was it 251 vs 249?
What line of code will extract predictions for individual trees?
My ideal output would look something like this:
```
> predictions_info
       setosa versicolor  virginica       pred
1  0.01517536 0.55449239 0.43033225 versicolor
2  0.21957988 0.71962024 0.06079987 versicolor
3  0.28146250 0.36777757 0.35075993 versicolor
4  0.51503150 0.41750308 0.06746543     setosa
5  0.25832598 0.10796878 0.63370523  virginica
6  0.24603616 0.07558151 0.67838233  virginica
7  0.02323489 0.41547464 0.56129047  virginica
8  0.41155830 0.49214444 0.09629726 versicolor
9  0.30217529 0.39852784 0.29929686 versicolor
10 0.45923782 0.49147493 0.04928725 versicolor
11 0.70479092 0.27648912 0.01871996     setosa
12 0.34489442 0.02606726 0.62903832  virginica
13 0.15553471 0.18903000 0.65543530  virginica
[...]
```
Here the `pred` column is what the `predict()` function currently returns, and the first 3 columns show what proportion of the 500 trees gave each prediction. (These numbers and predictions are made up and don't match the model output.)
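Each column in such a table is just the share of trees voting for that class. The base-R sketch below builds a table in that layout from a matrix of per-tree labels. The 2 x 4 vote matrix is made up, and `vote_share()` is a hypothetical helper, not a randomForest function.

```r
# Made-up per-tree predictions: 2 flowers (rows) x 4 trees (columns).
tree_votes <- matrix(
  c("setosa",    "setosa",     "setosa",    "versicolor",
    "virginica", "versicolor", "virginica", "virginica"),
  nrow = 2, byrow = TRUE
)
classes <- c("setosa", "versicolor", "virginica")

# Hypothetical helper: the share of trees voting for each class, plus a
# `pred` column holding the class with the largest share.
vote_share <- function(votes, levels) {
  shares <- t(apply(votes, 1, function(row) {
    as.numeric(table(factor(row, levels = levels))) / length(row)
  }))
  colnames(shares) <- levels
  out <- as.data.frame(shares)
  out$pred <- factor(levels[max.col(shares, ties.method = "first")],
                     levels = levels)
  out
}

predictions_info <- vote_share(tree_votes, classes)
predictions_info
```

Here the first row comes out as 0.75 / 0.25 / 0 with `pred` equal to `setosa`, which answers the "450 vs 50 or 251 vs 249?" question for that flower.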
*This example is originally from this website (modified by me): https://rpubs.com/Jay2548/519589
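Use `predict.all = TRUE`. `predict()` then returns a list whose `individual` component holds every tree's prediction (and whose `aggregate` component holds the usual forest prediction), so you can compute whatever you want from all the predictions. Be careful: `individual` is a matrix of size `nrow(newdata)` x number of trees.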
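A sketch of this answer applied to the question's model, assuming the randomForest package is installed. `shares` and `prob` are illustrative names, and reusing `predictions_info` for the final table is just to match the layout asked for. The sketch also uses `type = "prob"`, which for a classification forest returns the (by default normalised) vote fractions directly and so should agree with the shares computed from `individual`.

```r
library(randomForest)

# Same split and model as in the question.
set.seed(10)
index_row <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
train_data <- iris[index_row == 1, ]
test_data <- iris[index_row == 2, ]
iris_class <- randomForest(Species ~ ., data = train_data)

# With predict.all = TRUE, predict() returns a list:
#   $aggregate  - the usual forest prediction (one label per row)
#   $individual - character matrix, one row per test row, one column per tree
all_preds <- predict(iris_class, test_data[, -5], predict.all = TRUE)
dim(all_preds$individual)  # nrow(test_data) rows, 500 columns

# Share of trees voting for each species, per test row.
classes <- levels(iris$Species)
shares <- sapply(classes, function(cl) rowMeans(all_preds$individual == cl))

# Table laid out like the one in the question.
predictions_info <- data.frame(shares, pred = all_preds$aggregate)
head(predictions_info)

# type = "prob" returns the vote fractions directly; compare with `shares`.
prob <- predict(iris_class, test_data[, -5], type = "prob")
max(abs(shares - prob[, classes]))
```

If only the summary is needed, `type = "prob"` avoids materialising the full row-by-tree matrix; `predict.all = TRUE` is useful when you want to inspect individual trees.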