Tags: python, apache-spark, apache-spark-mllib, apache-spark-ml

Equivalent of pyspark.mllib.tree.DecisionTreeModel.toDebugString() in pyspark.ml.classification.DecisionTreeClassificationModel - IN PYTHON


This is essentially the same question as:

BUT for pyspark.

I used to be able to do something like:

from pyspark.mllib.tree import DecisionTree
# trainingData: RDD of LabeledPoint; categoricalFeatures: dict {feature index: number of categories}
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo=categoricalFeatures,
                                     impurity='gini', maxDepth=5, maxBins=16)
print(model.toDebugString())

and I would get a nice visualization of the decision tree:

DecisionTreeModel classifier of depth 5 with 49 nodes
  If (feature 1 in {0.0})
   If (feature 0 in {0.0})
    If (feature 2 <= 52.0)
     If (feature 3 <= 26.0)
      Predict: 0.0
...
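The debug string refers to columns only by index (`feature 0`, `feature 1`, ...). A minimal sketch, assuming you keep your own list of feature names in the same order as the feature vector, that swaps in readable names with a regular expression. `label_features` and the example names are hypothetical, not part of Spark:

```python
import re


def label_features(debug_string, feature_names):
    """Replace every 'feature <i>' in a toDebugString dump with feature_names[i].

    Assumes feature_names is ordered exactly like the model's feature vector.
    """
    return re.sub(
        r"feature (\d+)",  # \d+ so that 'feature 10' is not read as 'feature 1'
        lambda m: feature_names[int(m.group(1))],
        debug_string,
    )


tree = "If (feature 1 in {0.0})\n If (feature 2 <= 52.0)\n  Predict: 0.0"
print(label_features(tree, ["age", "sex", "income"]))  # hypothetical feature names
```

The same helper works on the output of either the `pyspark.mllib` or the `pyspark.ml` model, since both use the `feature <index>` wording.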

I am trying to port my code to pyspark.ml, but I don't see any way of printing the resulting tree:

from pyspark.ml.classification import DecisionTreeClassifier
dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures",
                            maxDepth=5, maxBins=16, impurity='gini')
# transformedTrainingData: DataFrame containing the indexedLabel and indexedFeatures columns
model = dt.fit(transformedTrainingData)

When I do:

print(model)

I only get the first line:

DecisionTreeClassificationModel (uid=DecisionTreeClassifier_4cbda3dcd0bddd9d4a0b) of depth 5 with 43 nodes

Thoughts on how to get the nice tree output?


Solution

  • I found a solution. It is not elegant and it violates encapsulation and everything you ever learned about object oriented programming, but it works:

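    # _call_java() is a private helper on PySpark's Java wrapper classes: it calls the named
    # method on the underlying JVM model and converts the result to a Python object
    print(model._call_java("toDebugString"))

  Output:
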
    DecisionTreeClassificationModel (uid=DecisionTreeClassifier_4c3bb548827f07c590e6) of depth 5 with 49 nodes
      If (feature 1 in {0.0})
       If (feature 0 in {1.0,2.0})
        If (feature 2 <= 5.0)
         If (feature 3 <= 26.0)
          Predict: 1.0
         Else (feature 3 > 26.0)
          If (feature 0 in {2.0})
    ...
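In PySpark 2.0 and later, as far as I know, the `pyspark.ml` tree models expose `toDebugString` directly as a property (no parentheses, unlike the `pyspark.mllib` method), so the private `_call_java` route should only be needed on older 1.x releases. A minimal sketch of a version-tolerant helper under that assumption; the name `tree_debug_string` is hypothetical:

```python
def tree_debug_string(model):
    """Return the full if/else description of a pyspark.ml tree model."""
    # Newer PySpark: the model class defines toDebugString as a property.
    attr = getattr(type(model), "toDebugString", None)
    if isinstance(attr, property):
        return model.toDebugString
    # Older PySpark: fall back to the private helper that calls the JVM model's method.
    return model._call_java("toDebugString")
```

Usage: `print(tree_debug_string(model))` with the fitted `DecisionTreeClassificationModel` from the question. Checking the class for a `property` avoids evaluating the attribute twice, since each access goes through the JVM.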