scala apache-spark apache-spark-mllib decision-tree

Get the default number of elements per leaf in a Decision Tree of Spark MLlib


I want to get the default number of elements per leaf in a Spark MLlib Decision Tree, if it is possible.

I've been reading here https://spark.apache.org/docs/latest/mllib-decision-tree.html and also trying to find something in https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala but I couldn't find the information that I need.

I know about the minInstancesPerNode Strategy parameter, but it isn't what I want.

Any ideas? Thanks!


Solution

  • A Spark DecisionTreeClassifier has several parameters that you can set with setXYZ methods before training. Several of them help you regularize the tree and avoid overfitting, e.g.

    • setMinInstancesPerNode: The minimum number of training records each child node must contain for a split to be valid (default 1). If a split would leave a child with fewer than minInstancesPerNode records, the split is discarded and the parent becomes a leaf.
    • setMaxDepth: The maximum depth of the tree after which the tree will stop growing.
    • setMinInfoGain: The minimum information gain for a split to occur

    Once you train (.fit) a Spark decision tree and then predict (.transform) you will have 3 additional columns in your DataFrame (for classification):

    • predictionCol: "Predicted label"
    • rawPredictionCol: "Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction"
    • probabilityCol: "Vector of length # classes equal to rawPrediction normalized to a multinomial distribution"

    The column rawPredictionCol might be what you are looking for. It tells you how many instances of each class ended up in the leaf after building the tree at training time. The predicted label is the class with the highest count. The probabilityCol (derived from rawPredictionCol) captures the "confidence" in the prediction. See: https://spark.apache.org/docs/latest/ml-classification-regression.html#output-columns
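
    Since rawPredictionCol holds per-class counts for the leaf that made the prediction, summing the vector should give the number of training records in that leaf. A minimal sketch, assuming the default output column name "rawPrediction" and an unweighted training set; leafCount, leafCountUdf, withLeafSize and the "leafSize" column are hypothetical names introduced here:

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

// Sum the per-class counts stored in rawPrediction.
// Hypothetical helper; not part of the Spark API.
def leafCount(raw: Vector): Double = raw.toArray.sum

// Wrap the helper so it can be applied to a DataFrame column.
val leafCountUdf = udf((raw: Vector) => leafCount(raw))

// `predictions` is assumed to be the result of model.transform(...).
def withLeafSize(predictions: DataFrame): DataFrame =
  predictions.withColumn("leafSize", leafCountUdf(col("rawPrediction")))
```

    Calling withLeafSize(model.transform(testData)).select("prediction", "leafSize") would then show, for each row, how many training records ended up in the leaf that produced its prediction; the minimum of leafSize over the training data gives the smallest leaf actually reached.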