I need to visualize a decision tree with dtreeviz in Databricks. The code seems to run fine. However, instead of showing the decision tree, it only outputs the following:
Out[23]: <dtreeviz.trees.DTreeViz at 0x7f5b27a91160>
This is the code I am running:
import pandas as pd
from sklearn import preprocessing, tree
from dtreeviz.trees import dtreeviz
Things = {'Feature01': [3,4,5,0],
          'Feature02': [4,5,6,0],
          'Feature03': [1,2,3,8],
          'Target01': ['Red','Blue','Teal','Red']}
df = pd.DataFrame(Things,
                  columns=['Feature01', 'Feature02',
                           'Feature03', 'Target01'])
label_encoder = preprocessing.LabelEncoder()
label_encoder.fit(df.Target01)
df['target'] = label_encoder.transform(df.Target01)
classifier = tree.DecisionTreeClassifier()
classifier.fit(df.iloc[:,:3], df.target)
dtreeviz(classifier,
         df.iloc[:,:3],
         df.target,
         target_name='toy',
         feature_names=df.columns[0:3],
         class_names=list(label_encoder.classes_)
         )
If you look into the dtreeviz documentation, you'll see that the dtreeviz function just creates an object, and you then need to call a method like .view() to show it. On Databricks, .view() won't work, but you can use the .svg() method to generate the output as SVG, and then use the displayHTML function to show it. The following code:
viz = dtreeviz(classifier,
...)
displayHTML(viz.svg())
will give you the desired output.
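As a rough sketch of the same idea, the two steps (render to SVG, then display) can be wrapped in a small helper. The name show_tree and its display parameter are hypothetical and not part of dtreeviz; the only assumption about viz is that it has the .svg() method mentioned above, returning the SVG markup as a string.

```python
def show_tree(viz, display):
    """Render a dtreeviz object to SVG and pass the markup to a display callable.

    `viz` is assumed to expose .svg() returning SVG text (as the DTreeViz
    object above does); `display` is any callable that accepts an HTML/SVG
    string, e.g. Databricks' displayHTML.
    """
    svg_text = viz.svg()   # render the tree to SVG markup
    display(svg_text)      # hand the markup to the notebook's display function
    return svg_text        # also return it, e.g. for saving to a file
```

On Databricks this would be called as show_tree(viz, displayHTML).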
P.S. You need to have the dot command-line tool installed to generate the output. It can be installed by executing the following in a notebook cell:
%sh apt-get install -y graphviz
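If you are unsure whether dot is already available on the driver, a quick check from a Python cell might look like the sketch below. It only uses the standard library; the helper name graphviz_dot_path is made up for illustration.

```python
import shutil


def graphviz_dot_path():
    """Return the full path of the `dot` executable, or None if it is not on PATH."""
    return shutil.which("dot")


dot = graphviz_dot_path()
if dot is None:
    # Graphviz is missing; install it first (see the %sh command above).
    print("dot not found on PATH")
else:
    print(f"dot found at {dot}")
```

Note that this only checks the machine where the Python cell runs, which is the driver node in a Databricks notebook.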