Plotting a decision tree with graphviz and pydotplus


I'm following the decision tree tutorial in the scikit-learn documentation. I have pydotplus 2.0.2, but it tells me that it does not have a write method (error below). I've been struggling with this for a while now; any ideas, please? Many thanks!

from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

from IPython.display import Image

dot_data = tree.export_graphviz(clf, out_file=None)
import pydotplus

graph = pydotplus.graphviz.graph_from_dot_data(dot_data)

Image(graph.create_png())

And my error is:

    /Users/air/anaconda/bin/python /Users/air/PycharmProjects/kiwi/hemr.py
Traceback (most recent call last):
  File "/Users/air/PycharmProjects/kiwi/hemr.py", line 10, in <module>
    dot_data = tree.export_graphviz(clf, out_file=None)
  File "/Users/air/anaconda/lib/python2.7/site-packages/sklearn/tree/export.py", line 375, in export_graphviz
    out_file.write('digraph Tree {\n')
AttributeError: 'NoneType' object has no attribute 'write'

Process finished with exit code 1
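The traceback shows export_graphviz calling write() directly on whatever was passed as out_file. The stdlib-only sketch below mimics the two behaviours with hypothetical functions export_dot_old and export_dot_new; they are simplified stand-ins for illustration, not scikit-learn's actual code.

```python
import io


def export_dot_old(lines, out_file="tree.dot"):
    """Hypothetical stand-in for older releases: out_file must be a path or file object."""
    own_file = isinstance(out_file, str)
    if own_file:
        out_file = open(out_file, "w")
    try:
        # With out_file=None this line raises AttributeError, as in the traceback.
        out_file.write("digraph Tree {\n")
        for line in lines:
            out_file.write(line + "\n")
        out_file.write("}")
    finally:
        if own_file:
            out_file.close()
    # Nothing is returned: the DOT source only ends up in the file.


def export_dot_new(lines, out_file=None):
    """Hypothetical stand-in for newer releases: out_file=None returns the DOT source."""
    return_string = out_file is None
    if return_string:
        out_file = io.StringIO()
    export_dot_old(lines, out_file)
    if return_string:
        return out_file.getvalue()


try:
    export_dot_old(['0 [label="root"] ;'], out_file=None)
except AttributeError as exc:
    print(exc)  # 'NoneType' object has no attribute 'write'

print(export_dot_new(['0 [label="root"] ;']))
```

In other words, the same call succeeds or fails depending only on whether the installed version has the out_file=None branch.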

----- UPDATE -----

After applying the out_file fix, it throws another error:

 Traceback (most recent call last):
  File "/Users/air/PycharmProjects/kiwi/hemr.py", line 13, in <module>
    graph = pydotplus.graphviz.graph_from_dot_data(dot_data)
  File "/Users/air/anaconda/lib/python2.7/site-packages/pydotplus/graphviz.py", line 302, in graph_from_dot_data
    return parser.parse_dot_data(data)
  File "/Users/air/anaconda/lib/python2.7/site-packages/pydotplus/parser.py", line 548, in parse_dot_data
    if data.startswith(codecs.BOM_UTF8):
AttributeError: 'NoneType' object has no attribute 'startswith'

---- UPDATE 2 -----

Also, see my own answer below, which solves another problem.


Solution

  • The problem is that you are setting the parameter out_file to None.

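    In newer scikit-learn releases, out_file=None makes export_graphviz return the DOT source as a string instead of writing a file. The version you have installed does not support that: it treats out_file as a file object and calls out_file.write(...) on None, hence the AttributeError. If you simply drop out_file, that version writes the output to tree.dot and returns None, which is why passing its return value to graph_from_dot_data then fails with 'NoneType' object has no attribute 'startswith'.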

    Therefore, write the DOT source to a file and load the graph from that file:

    tree.export_graphviz(clf, out_file="tree.dot")
    graph = pydotplus.graphviz.graph_from_dot_file("tree.dot")
    Image(graph.create_png())
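If you would rather keep the DOT source in memory than write tree.dot, a file-like buffer can be passed as out_file; this should work whether or not the installed version understands out_file=None. A minimal sketch, assuming Python 3 (on Python 2, io.StringIO only accepts unicode text); dot_source is a hypothetical helper name:

```python
import io

from sklearn import tree
from sklearn.datasets import load_iris


def dot_source(clf, **kwargs):
    """Return the DOT source of a fitted tree as a string (hypothetical helper).

    export_graphviz writes into the buffer instead of a file, so this does not
    depend on out_file=None being supported.
    """
    buf = io.StringIO()
    tree.export_graphviz(clf, out_file=buf, **kwargs)
    return buf.getvalue()


iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
dot_data = dot_source(clf, feature_names=iris.feature_names)
print(dot_data.splitlines()[0])  # digraph Tree {
```

The returned string can then go to pydotplus.graph_from_dot_data(dot_data), followed by Image(graph.create_png()) as in the question; create_png needs the Graphviz dot executable to be installed and on the PATH.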