
How to obtain the interval limits from a decision tree with scikit-learn?


Say I am using the Titanic dataset, with only the age variable:

import numpy as np
import pandas as pd

data = pd.read_csv('https://www.openml.org/data/get_csv/16826755/phpMYEkMl')[["age", "survived"]]
# Missing ages are encoded as '?'; turn them into NaN, then fill with 0.
data = data.replace('?', np.nan)
data = data.fillna(0)
print(data)

the result:

         age  survived
0         29         1
1     0.9167         1
2          2         0
3         30         0
4         25         0
...      ...       ...
1304    14.5         0
1305       0         0
1306    26.5         0
1307      27         0
1308      29         0

[1309 rows x 2 columns]

Now I train a decision tree to predict survival from age:

from sklearn.tree import DecisionTreeClassifier
tree_model = DecisionTreeClassifier(max_depth=3)
tree_model.fit(data['age'].to_frame(),data["survived"])

And if I print the structure of the tree:

from sklearn import tree
print(tree.export_text(tree_model))

I obtain:

|--- feature_0 <= 0.08
|   |--- class: 0
|--- feature_0 >  0.08
|   |--- feature_0 <= 8.50
|   |   |--- feature_0 <= 1.50
|   |   |   |--- class: 1
|   |   |--- feature_0 >  1.50
|   |   |   |--- class: 1
|   |--- feature_0 >  8.50
|   |   |--- feature_0 <= 60.25
|   |   |   |--- class: 0
|   |   |--- feature_0 >  60.25
|   |   |   |--- class: 0

This means that the leaves partition age into these intervals:

0-0.08; 0.08-1.50; 1.50-8.50; 8.50-60.25; >60.25

My question is: how can I capture those limits in an array that looks like this?

[-np.inf, 0.08, 1.5, 8.5, 60.25, np.inf]

Thank you!


Solution

  • A fitted decision tree classifier, in this case tree_model, has an attribute called tree_ that exposes the low-level structure of the tree, including the per-node arrays threshold and feature.

    print(repr(tree_model.tree_.threshold))
    
    array([ 0.08335, -2.     ,  8.5    ,  1.5    , -2.     , -2.     ,
           60.25   , -2.     , -2.     ])
    
    print(repr(tree_model.tree_.feature))
    
    array([ 0, -2,  0,  0, -2, -2,  0, -2, -2], dtype=int64)
    

    The feature and threshold arrays are only meaningful for split nodes. The entries for leaf nodes are placeholders (scikit-learn fills them with -2, as in the output above) and should be ignored.
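
Rather than relying on the placeholder value, leaves can also be identified from the children_left array, where scikit-learn marks a missing child with -1. The following is a minimal sketch on synthetic data (the age and survived arrays and the name toy_model are made up for illustration and are not the Titanic data):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the age/survived columns (assumption, not the real data).
rng = np.random.default_rng(0)
age = rng.uniform(0, 80, size=200).reshape(-1, 1)
survived = ((age.ravel() < 10) | (age.ravel() > 60)).astype(int)

toy_model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(age, survived)
t = toy_model.tree_

# A node is a leaf when it has no left child; scikit-learn stores -1 there.
is_leaf = t.children_left == -1

# Keep only the thresholds that belong to real split nodes.
split_thresholds = t.threshold[~is_leaf]
print(sorted(split_thresholds))
```

With a single feature this gives the same values as filtering on the feature array; the leaf mask is mainly useful when you want all split nodes regardless of which feature they use.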

    To get the split thresholds for a given feature, filter the threshold array with a boolean mask built from the feature array.

    threshold = tree_model.tree_.threshold
    feature = tree_model.tree_.feature
    feature_threshold = threshold[feature == 0]
    thresholds = sorted(feature_threshold)
    print(thresholds)
    
    [0.08335000276565552, 1.5, 8.5, 60.25]
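
If the tree is trained on more than one feature, the same filtering can be wrapped in a small helper. The sketch below is one possible way to do it, not part of scikit-learn: interval_edges is a hypothetical name, the data is synthetic, and np.unique is used because the same threshold for a feature can appear in several branches.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def interval_edges(model, feature_index):
    """Sorted, de-duplicated split points for one feature, padded with -inf/+inf."""
    t = model.tree_
    # Leaf nodes carry feature == -2, so this mask selects split nodes only.
    cuts = np.unique(t.threshold[t.feature == feature_index])
    return np.concatenate(([-np.inf], cuts, [np.inf]))

# Synthetic two-feature data (assumption, for illustration only).
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(300, 2))
y = ((X[:, 0] > 0.5) & (X[:, 1] > 0.3)).astype(int)
toy_model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

edges_0 = interval_edges(toy_model, 0)
edges_1 = interval_edges(toy_model, 1)
print(edges_0)
print(edges_1)
```

A feature that the tree never splits on simply yields the two infinite edges.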
    

    The tree only stores the internal split points, so the outer limits -np.inf and np.inf have to be added manually.

    import numpy as np
    thresholds = [-np.inf] + thresholds + [np.inf]
    print(thresholds)
    
    [-inf, 0.08335000276565552, 1.5, 8.5, 60.25, inf]
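
Once the limits are in an array, they can be used as bin edges. Because the tree sends a sample to the left branch when its value is <= the threshold, each interval is closed on the right. Here is a minimal sketch using np.digitize with the edges copied from the output above; the ages array is made up for illustration:

```python
import numpy as np

# Edges copied from the output above.
edges = np.array([-np.inf, 0.08335000276565552, 1.5, 8.5, 60.25, np.inf])

# Hypothetical ages, chosen to land in different intervals.
ages = np.array([0.0, 1.5, 29.0, 70.0])

# right=True means edges[i-1] < x <= edges[i], matching the tree's "<=" splits.
# Subtract 1 so the first interval gets index 0.
interval_index = np.digitize(ages, edges, right=True) - 1
print(interval_index)  # [0 1 3 4]
```

Note that an age of exactly 1.5 falls into the second interval (0.08, 1.5], consistent with the "feature_0 <= 1.50" branch of the tree.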
    

    Reference: Understanding the decision tree structure.