python scikit-learn decision-tree

ImportError: cannot import name 'tree' from 'sklearn.tree'


I am on my second day of re-taking Python for the gazillionth time! I am doing a tutorial on ML in Python, using the following code:

import sklearn.tree
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import tree

music_data = pd.read_csv('music.csv')
x = music_data.drop(columns=['genre'])
y = music_data['genre']

model = DecisionTreeClassifier()
model.fit(x,y)

tree.export_graphviz(model, out_file='music-recommender.dot',
                feature_names=['age','gender'],
                class_names= sorted(y.unique()),
                label='all',
                rounded=True,
                filled=True)

I keep getting the following error:

ImportError                               Traceback (most recent call last)
 ~\AppData\Local\Temp/ipykernel_13088/3820271611.py in <module>
      2 import pandas as pd
      3 from sklearn.tree import DecisionTreeClassifier
----> 4 from sklearn.tree import tree
      5 
      6 music_data = pd.read_csv('music.csv')

ImportError: cannot import name 'tree' from 'sklearn.tree' (C:\Anaconda\lib\site-packages\sklearn\tree\__init__.py)

I've tried to find a solution online, but I don't think it's the version of Python/Anaconda because I literally just installed both. I also don't think it's sklearn.tree itself, since I was able to import DecisionTreeClassifier.


Solution

  • As this answer indicates, you're looking at some older code; library APIs change between releases, so this is always a risk when you follow older tutorials. But there's another thing you need to know about your code.

    First off, scikit-learn contains several modules, and almost everything you need from it is in one of those. In my experience, most people import things like this:

    from sklearn.tree import DecisionTreeRegressor   # A regressor class.
    from sklearn.tree import plot_tree               # A helpful function.
    from sklearn.metrics import mean_squared_error   # An evaluation function.
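
    To make the import styles concrete, here is a minimal sketch, assuming a reasonably recent scikit-learn. It shows that importing the `tree` subpackage from `sklearn` and importing a name out of `sklearn.tree` reach the same objects, which is probably what the failing line in the question was aiming for.

```python
from sklearn import tree                          # the subpackage itself
from sklearn.tree import DecisionTreeClassifier   # one name taken from it

# Both spellings refer to the same class object.
same_class = tree.DecisionTreeClassifier is DecisionTreeClassifier
print(same_class)

# The export helper used in the question lives on the subpackage too.
has_export = hasattr(tree, "export_graphviz")
print(has_export)
```

    So `from sklearn import tree` followed by `tree.export_graphviz(...)` should work, whereas `from sklearn.tree import tree` asks for a submodule that current releases no longer ship.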
    

    It looks like the tutorial wants something similar to plot_tree(). This new-ish function (added in scikit-learn 0.21) is much easier to use than the older Graphviz visualization. So unless you really need the DOT file for some reason, you should be able to do this:

    from sklearn.tree import plot_tree
    
    plot_tree(model)   # Call the imported name directly; draws with matplotlib.
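
    If you do need the DOT file after all, the following is a rough sketch of the question's code with the import changed to pull `export_graphviz` straight from `sklearn.tree`. The small DataFrame is a made-up stand-in for music.csv (the real file isn't shown), and `random_state=0` is only there to make the run repeatable.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Made-up stand-in for music.csv; the values are illustrative only.
music_data = pd.DataFrame({
    "age": [20, 23, 25, 26, 30, 31],
    "gender": [1, 1, 1, 0, 0, 0],
    "genre": ["HipHop", "HipHop", "Jazz", "Dance", "Acoustic", "Classical"],
})
X = music_data.drop(columns=["genre"])
y = music_data["genre"]

model = DecisionTreeClassifier(random_state=0)
model.fit(X, y)

# out_file=None returns the DOT source as a string; pass
# out_file="music-recommender.dot" to write the file as in the question.
dot_source = export_graphviz(
    model,
    out_file=None,
    feature_names=list(X.columns),
    class_names=sorted(y.unique()),
    label="all",
    rounded=True,
    filled=True,
)
print(dot_source[:60])
```

    Taking `feature_names` from `X.columns` instead of hard-coding `['age', 'gender']` keeps the labels in step with the columns the model was actually trained on.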
    

    Bottom line: there will probably be more broken things in that material. So if I were you I'd either make a new environment with a version of sklearn matching whatever material you're using... or ditch that material and look for something newer.
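
    Before deciding between those two options, it can help to check what is actually installed. This is a small sketch, not an official diagnostic; the variable names are just for illustration. As far as I recall, the old `sklearn.tree.tree` submodule was deprecated in 0.22 and removed in 0.24, so treat those version numbers as something to verify against the release notes.

```python
import importlib.util

import sklearn

# The installed version, to compare against what the tutorial was written for.
installed_version = sklearn.__version__
print("scikit-learn version:", installed_version)

# find_spec returns None when the submodule cannot be found,
# which is the situation that makes `from sklearn.tree import tree` fail.
old_submodule_spec = importlib.util.find_spec("sklearn.tree.tree")
has_old_submodule = old_submodule_spec is not None
print("sklearn.tree.tree available:", has_old_submodule)
```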