
Python Databricks cannot visualise dtreeviz decision tree


I need to visualize a decision tree with dtreeviz in Databricks. The code runs without errors. However, instead of showing the decision tree, the cell only prints the following:

Out[23]: <dtreeviz.trees.DTreeViz at 0x7f5b27a91160>

Running the following code:

import pandas as pd
from sklearn import preprocessing, tree
from dtreeviz.trees import dtreeviz

Things = {'Feature01': [3,4,5,0], 
          'Feature02': [4,5,6,0], 
          'Feature03': [1,2,3,8], 
          'Target01': ['Red','Blue','Teal','Red']}
df = pd.DataFrame(Things,
                  columns=['Feature01', 'Feature02',
                           'Feature03', 'Target01'])

label_encoder = preprocessing.LabelEncoder()
label_encoder.fit(df.Target01)
df['target'] = label_encoder.transform(df.Target01)

classifier = tree.DecisionTreeClassifier()
classifier.fit(df.iloc[:,:3], df.target)

dtreeviz(classifier,
         df.iloc[:,:3],
         df.target,
         target_name='toy',
         feature_names=df.columns[0:3],
         class_names=list(label_encoder.classes_)
         )
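The class_names argument works because LabelEncoder numbers the classes in the order of label_encoder.classes_, which is the sorted list of unique labels. The pure-Python sketch below mimics that mapping for the toy data; encode_labels is a hypothetical helper for illustration, not part of scikit-learn.

```python
from typing import List, Tuple


def encode_labels(labels: List[str]) -> Tuple[List[str], List[int]]:
    """Mimic LabelEncoder: classes are the sorted unique labels,
    and each label is encoded as its index in that sorted list."""
    classes = sorted(set(labels))
    index = {label: i for i, label in enumerate(classes)}
    return classes, [index[label] for label in labels]


# Same target values as Target01 in the question
classes, codes = encode_labels(['Red', 'Blue', 'Teal', 'Red'])
print(classes)  # ['Blue', 'Red', 'Teal']
print(codes)    # [1, 0, 2, 1]
```

Index i of label_encoder.classes_ therefore names encoded class i, which is the ordering class_names is expected to follow.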

Solution

  • If you look at the dtreeviz documentation, you'll see that the dtreeviz function only creates a DTreeViz object; to render it, you then call a method such as .view(). On Databricks, .view() won't work, because it tries to open the image in an external viewer rather than in the notebook. Instead, you can use the .svg() method to generate the output as an SVG string and pass it to the displayHTML function to show it. The following code:

    viz = dtreeviz(classifier,
      ...)
    displayHTML(viz.svg())
    

    will render the decision tree in the cell output instead of the DTreeViz object's text representation.

    P.S. The Graphviz dot command-line tool must be installed to generate the output. You can install it by running the following in a notebook cell:

    %sh apt-get install -y graphviz
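If you render trees often, the steps above can be bundled into a small helper. This is a sketch under assumptions: show_tree, wrap_svg and dot_available are hypothetical names, displayHTML exists only inside Databricks notebooks (so it is passed in as display_fn), and viz is assumed to expose the .svg() method used above. The helper checks for dot with shutil.which before rendering and wraps the SVG in a scrollable div so a wide tree does not overflow the cell.

```python
import shutil
from typing import Callable


def dot_available() -> bool:
    """Return True if the Graphviz 'dot' executable is found on PATH."""
    return shutil.which("dot") is not None


def wrap_svg(svg: str, max_height_px: int = 600) -> str:
    """Wrap an SVG string in a scrollable <div> so large trees stay inside the cell."""
    return (
        f'<div style="overflow:auto; max-height:{max_height_px}px;">'
        f"{svg}"
        "</div>"
    )


def show_tree(viz, display_fn: Callable[[str], None], check_dot: bool = True) -> None:
    """Render a dtreeviz object through an HTML display function.

    Assumption: viz has an .svg() method returning the SVG as a string,
    and display_fn is something like Databricks' displayHTML.
    """
    if check_dot and not dot_available():
        raise RuntimeError(
            "Graphviz 'dot' not found on PATH; install it first, e.g. "
            "%sh apt-get install -y graphviz"
        )
    display_fn(wrap_svg(viz.svg()))
```

In a Databricks notebook this would be called as show_tree(viz, displayHTML) after building viz with dtreeviz(classifier, ...).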