Tags: python, graphviz, decision-tree, pydotplus

What is the solid black rectangle adjacent to the decision tree?


I adapted this code from https://www.dasca.org/world-of-big-data/article/know-how-to-create-and-visualize-a-decision-tree-with-python.

I removed two arguments from the DecisionTreeClassifier constructor, min_impurity_split=None and presort=False, but otherwise the code is the same as I found it.

import sklearn.datasets as datasets
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from io import StringIO  # standard-library StringIO; six is not needed
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus
iris=datasets.load_iris()
df=pd.DataFrame(iris.data, columns=iris.feature_names)
y=iris.target
dtree=DecisionTreeClassifier()
dtree.fit(df,y)

# Limit max depth
model = RandomForestClassifier(max_depth = 3, n_estimators=10)
# Train
model.fit(iris.data, iris.target)
# Extract single tree
estimator_limited = model.estimators_[5]
estimator_limited

# Removed  min_impurity_split=None and presort=False because they caused "unexpected keyword argument" errors
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
            max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, 
            random_state=1538259045, splitter='best')
# No max depth
model = RandomForestClassifier(max_depth = None, n_estimators=10)
model.fit(iris.data, iris.target)
estimator_nonlimited = model.estimators_[5]

dot_data = StringIO()
export_graphviz(dtree, out_file=dot_data,  
                filled=True, rounded=True,
                special_characters=True)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
Image(graph.create_png())

The decision tree looks like this: [image: rendered decision tree with a solid black rectangle next to it]


Solution

  • The black rectangle is caused by a bug in pydotplus. You can work around it by stripping the newline characters from the DOT data before parsing it. Change the corresponding line in your code to:

    graph = pydotplus.graph_from_dot_data(dot_data.getvalue().replace("\n", ""))
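To see what that replace call does to the DOT text, here is a minimal sketch using only the standard library. The DOT fragment is hand-written for illustration (it is not actual export_graphviz output), and strip_newlines is a hypothetical helper name. It assumes every statement ends with a semicolon, as in the fragment; without separators, removing newlines could fuse adjacent tokens.

```python
# Hand-written DOT fragment for illustration only (an assumption, not real
# export_graphviz output). The label on node 0 contains the two-character
# escape sequence backslash + n, which Graphviz renders as a line break.
dot_source = (
    'digraph Tree {\n'
    'node [shape=box, style="filled, rounded", color="black"] ;\n'
    'edge [fontname=helvetica] ;\n'
    '0 [label="petal width <= 0.8\\ngini = 0.667"] ;\n'
    '1 [label="gini = 0.0"] ;\n'
    '0 -> 1 ;\n'
    '}'
)


def strip_newlines(dot: str) -> str:
    """Remove real newline characters from DOT text.

    Hypothetical helper. Only actual newline characters are removed;
    escaped line breaks inside labels are left untouched.
    """
    return dot.replace("\n", "")


cleaned = strip_newlines(dot_source)

# The real newlines between statements are gone ...
print("\n" in cleaned)
# ... but the escaped line break inside the label is still there.
print("\\n" in cleaned)
```

The cleaned string is what would then be passed to pydotplus.graph_from_dot_data. Because labels keep their escaped line breaks, the text inside each node should still wrap as before.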