
How to make a decision tree chart from a random forest and filter it by category


I'm learning machine learning in R and building a decision tree for expired products. I have the following data:

Product, Category, Temperature, Expire_Day, Rotation_Day, Weight, State
Tapa, Pulpa, 0, 30, 21, 4.21, No
Tapa, Pulpa, 0, 30, 21, 3.82, Expire
Nalga, Pulpa, 0, 30, 25, 6.10, No
Nalga, Pulpa, 0, 30, 25, 5, Expire
Costeleta, Bife, 7, 5, 3, 1.10, No
Costeleta, Bife, 7, 5, 3, 2.25, No
Costeleta, Bife, 7, 5, 3, 0.9, Expire
Brazuelo, Bife, 7, 5, 3, 2.5, No

From this I build the model input: I one-hot encode the Product and Category columns into indicator vectors using dummyVars and normalize Weight with min-max scaling. Temperature, Expire_Day and Rotation_Day have a proximity relationship (nearby values mean similar things), so I didn't convert them. Finally, I convert State to a factor.
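A minimal base-R sketch of this preprocessing on the sample rows above, assuming plain one-hot encoding and min-max scaling to [0, 1]. It uses two hand-written helpers (`one_hot` and `min_max`, hypothetical names) instead of caret's dummyVars so it runs without extra packages, and it is not meant to reproduce the exact Weight values in the table that follows.

```r
# Sample rows from the question
df <- data.frame(
  Product = c("Tapa", "Tapa", "Nalga", "Nalga",
              "Costeleta", "Costeleta", "Costeleta", "Brazuelo"),
  Category = c("Pulpa", "Pulpa", "Pulpa", "Pulpa", "Bife", "Bife", "Bife", "Bife"),
  Temperature = c(0, 0, 0, 0, 7, 7, 7, 7),
  Expire_Day = c(30, 30, 30, 30, 5, 5, 5, 5),
  Rotation_Day = c(21, 21, 25, 25, 3, 3, 3, 3),
  Weight = c(4.21, 3.82, 6.10, 5, 1.10, 2.25, 0.9, 2.5),
  State = c("No", "Expire", "No", "Expire", "No", "No", "Expire", "No")
)

# Hypothetical helper: one 0/1 column per level, named <prefix>.<level>
one_hot <- function(x, prefix) {
  lv <- unique(x)
  out <- sapply(lv, function(l) as.integer(x == l))
  colnames(out) <- paste(prefix, lv, sep = ".")
  out
}

# Hypothetical helper: rescale a numeric vector to [0, 1]
min_max <- function(x) (x - min(x)) / (max(x) - min(x))

encoded <- data.frame(
  one_hot(df$Product, "Product"),
  one_hot(df$Category, "Category"),
  Temperature = df$Temperature,     # left as-is
  Expire_Day = df$Expire_Day,       # left as-is
  Rotation_Day = df$Rotation_Day,   # left as-is
  Weight = min_max(df$Weight),
  State = factor(df$State)          # target as a factor
)
str(encoded)
```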

The resulting data set is:

Product.Tapa, Product.Nalga, Product.Costeleta, Product.Brazuelo, Category.Pulpa, Category.Bife, Temperature, Expire_Day, Rotation_Day, Weight, State
1, 0, 0, 0, 1, 0, 0, 30, 21, 0.9, No
1, 0, 0, 0, 1, 0, 0, 30, 21, 0.78, Expire
0, 1, 0, 0, 1, 0, 0, 30, 25, 0.99, No
0, 1, 0, 0, 1, 0, 0, 30, 25, 0.72, Expire
0, 0, 1, 0, 0, 1, 7, 5, 3, 0.12, No
0, 0, 1, 0, 0, 1, 7, 5, 3, 0.22, No
0, 0, 1, 0, 0, 1, 7, 5, 3, 0.88, Expire
0, 0, 0, 1, 0, 1, 7, 5, 3, 0.5, No

With this data I train a random forest with the following code:

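library(randomForest)

mtry <- 6
ntree <- 24

# State is the target (factor) column of the encoded data.
# trControl is a caret::train() argument, and varimp / oob_score are not
# randomForest() arguments, so they are left out here.
rf_model <- randomForest(State ~ .,
       data = trainData,
       mtry = mtry,
       ntree = ntree,
       importance = TRUE,
       weight = data_weights)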

So far, my predictions reach a precision of 0.90 and the model works quite well, but I can't get the decision tree chart to work. It has to meet two conditions:

a- The decision tree must start with the Expire_Day and Rotation_Day columns, which are the most important in the data

b- Be able to filter or classify the tree by Category: for example, I have to be able to see only the tree for "Pulpa" without showing what corresponds to "Bife", then switch to see only "Bife", or the entire tree if they ask me

I haven't found a way to make this work yet. How should I do it?


Solution

  • I think you ran into a few misunderstandings. As I read it, you fitted a randomForest model and now you would like to plot the actual decision tree of that model.

    First, note that randomForest is an ensemble model. That means there is not one tree: to get the best output (e.g. your 0.9 precision), the results of many individual decision trees are combined (for classification, by majority vote). How many trees are combined is set by the hyperparameter ntree; the default in the randomForest package is 500.

    So your rf_model is a combination of 24 individual trees, because you set ntree <- 24 (with the default it would be 500). Here is an example with the iris dataset, since I don't have your data:

    library("randomForest")
    iris_model <- randomForest(Species ~., data = iris) 
    getTree(iris_model, k = 2)
    

    With getTree() you can look at the individual decision trees used in a randomForest model; k specifies which of the 500 trees to look at. The output is a table, not a chart, and it usually does not make much sense to study these individual trees in detail.
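    A slightly more readable variant of the getTree() call above: labelVar = TRUE replaces variable and class indices with their names. The seed is an arbitrary choice that only makes the run repeatable.

```r
library(randomForest)

set.seed(1)  # arbitrary, for reproducible bootstrap samples
iris_model <- randomForest(Species ~ ., data = iris)

# Number of trees in the forest (500 with the default ntree)
iris_model$ntree

# One row per node; status == -1 marks a terminal node (leaf)
tree2 <- getTree(iris_model, k = 2, labelVar = TRUE)
head(tree2)
```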

    The reason is that, by design (it is worth reading more about how random forests work), a single one of these trees will usually not give you very good precision. Each tree is grown on a bootstrap sample of the data and considers only a random subset of mtry variables at each split, so on its own it is weaker than a single CART tree fitted on all the data. It is their combination that makes random forests powerful and less prone to overfitting.
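    To make the "combination" concrete, here is a small sketch on iris (ntree = 25 and the seed are arbitrary choices) that uses predict(..., predict.all = TRUE) to get every tree's individual prediction and then takes the majority vote by hand. For classification, that vote is what the forest reports; the hand-written version simply takes the first class if there is an exact tie.

```r
library(randomForest)

set.seed(42)  # arbitrary, for reproducibility
rf <- randomForest(Species ~ ., data = iris, ntree = 25)

# predict.all = TRUE also returns each tree's own prediction
p <- predict(rf, newdata = iris[c(1, 51, 101), ], predict.all = TRUE)

# One row per observation, one column per tree
dim(p$individual)

# Majority vote across the 25 trees for each observation
majority <- apply(p$individual, 1, function(v) names(which.max(table(v))))

# Compare with the forest's aggregated prediction
data.frame(majority = majority, forest = as.character(p$aggregate))
```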

    There is some research on extracting a single 'representative' tree from a random forest ensemble; for example, you can use the reprtree package.

    reprtree is a package that implements the concept of representative trees from ensembles of tree-based machines introduced by Banerjee et al. (2012).

    Using the iris_model created above, the code looks like this:

    # Must be installed from github
    devtools::install_github("araastat/reprtree")
    
    library("reprtree")
    reprtree:::plot.getTree(iris_model)
    

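    This gives you a tree plot like the one below. (plot.getTree() draws a single tree taken from the forest, the first one unless you pass a different k; the package's ReprTree() function is what computes the representative trees.)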

    [Image: tree plot produced by reprtree:::plot.getTree()]

    Overall, though, it sounds like you just want a nice-looking decision tree that you can show someone so they can make a decision. In that case you could simply fit a single CART tree and plot it. (The random forest as a whole may have higher precision, but a single CART tree most likely has higher precision than any one of the individual trees that are combined in the random forest.)

    For this you can use the rpart package in R; combined with the rpart.plot package, it produces really nice plots of your tree.

    library("rpart")
    library("rpart.plot")
    
    #fit new model with rpart
    iris_model2 <- rpart(Species ~., data = iris)
    
    # plot the tree
    rpart.plot(iris_model2)
    

    [Image: tree plot produced by rpart.plot()]

    About your specific requirements:

    a- The decision tree must start with the Expire_Day and Rotation_Day columns, which are the most important in the data

    Regardless of which tree-based model you use (rpart, randomForest or one of the many other decision tree variants), the splits are derived from the data you trained the model on: at each node the algorithm picks the split that best separates the classes in that node. So it may just be your intuition that Rotation_Day and Expire_Day are the most important factors; if they do not appear prominently in the model, the data says otherwise (or something else went wrong).
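    A quick way to check which variables the data actually favours, sketched on iris since that is the example used above: look at the variable chosen for the root split and at rpart's variable importance. For the random forest itself, importance(rf_model) and varImpPlot(rf_model) from the randomForest package report its importance measures (MeanDecreaseAccuracy is only available when the forest is fitted with importance = TRUE, as in the question).

```r
library(rpart)

fit <- rpart(Species ~ ., data = iris)

# Variable used in the root split (first row of the node frame)
root_var <- as.character(fit$frame$var[1])
root_var

# rpart's importance score for each variable
fit$variable.importance
```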

    b- Be able to filter or classify the tree by Category: for example, I have to be able to see only the tree for "Pulpa" without showing what corresponds to "Bife", then switch to see only "Bife", or the entire tree if they ask me

    You could filter the input data used to build the model so that it only includes "Pulpa", and then build and plot the tree. This gives nice insight into how "Pulpa" would be classified (but it might not be the exact tree used in a combined model). Do the same with "Bife", or use all rows for the entire tree.
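    A minimal rpart sketch of that filtering, assuming the eight sample rows from the question and a hypothetical helper `category_tree()` (not part of any package). Passing no category grows the tree on all rows; passing "Pulpa" or "Bife" grows it on that subset only. The rpart.control() settings are loosened only because the sample is tiny; with real data the defaults are likely more sensible.

```r
library(rpart)

# Raw sample rows from the question; rpart handles factor columns itself
meat <- data.frame(
  Product = c("Tapa", "Tapa", "Nalga", "Nalga",
              "Costeleta", "Costeleta", "Costeleta", "Brazuelo"),
  Category = c("Pulpa", "Pulpa", "Pulpa", "Pulpa", "Bife", "Bife", "Bife", "Bife"),
  Temperature = c(0, 0, 0, 0, 7, 7, 7, 7),
  Expire_Day = c(30, 30, 30, 30, 5, 5, 5, 5),
  Rotation_Day = c(21, 21, 25, 25, 3, 3, 3, 3),
  Weight = c(4.21, 3.82, 6.10, 5, 1.10, 2.25, 0.9, 2.5),
  State = c("No", "Expire", "No", "Expire", "No", "No", "Expire", "No"),
  stringsAsFactors = TRUE
)

# Hypothetical helper: fit on all rows (category = NULL) or on one Category only
category_tree <- function(data, category = NULL) {
  if (!is.null(category)) {
    data <- droplevels(data[data$Category == category, ])
    data$Category <- NULL  # constant after filtering, so it cannot split
  }
  rpart(State ~ ., data = data, method = "class",
        # tiny sample: allow small nodes and skip cross-validation
        control = rpart.control(minsplit = 2, minbucket = 1, xval = 0))
}

tree_all <- category_tree(meat)            # entire tree
tree_pulpa <- category_tree(meat, "Pulpa") # only "Pulpa"
tree_bife <- category_tree(meat, "Bife")   # only "Bife"

# rpart.plot is an extra package, so plot only if it is installed
if (requireNamespace("rpart.plot", quietly = TRUE)) {
  rpart.plot::rpart.plot(tree_pulpa)
}
```

    The same filter-then-fit approach works for randomForest(); only the model call changes.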