python machine-learning scikit-learn decision-tree random-forest

How to extract the decision rules from scikit-learn decision-tree?


Can I extract the underlying decision rules (or 'decision paths') from a trained scikit-learn decision tree as a textual list?

Something like:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Solution

  • I believe that this answer is more correct than the other answers here:

    from sklearn.tree import _tree
    
    def tree_to_code(tree, feature_names):
        tree_ = tree.tree_
        # Leaf nodes store TREE_UNDEFINED (-2) as their feature index,
        # so guard the lookup instead of indexing feature_names blindly.
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        print("def tree({}):".format(", ".join(feature_names)))
    
        def recurse(node, depth):
            indent = "  " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                # Decision node: print the test, then both subtrees.
                name = feature_name[node]
                threshold = tree_.threshold[node]
                print("{}if {} <= {}:".format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                print("{}else:  # if {} > {}".format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                # Leaf node: print the value stored at this leaf.
                print("{}return {}".format(indent, tree_.value[node]))
    
        # Start at the root (node 0), indented one level inside the def.
        recurse(0, 1)
    

    This prints out a valid Python function. Here's an example output for a tree that is trying to return its input, a number between 0 and 10.

    def tree(f0):
      if f0 <= 6.0:
        if f0 <= 1.5:
          return [[ 0.]]
        else:  # if f0 > 1.5
          if f0 <= 4.5:
            if f0 <= 3.5:
              return [[ 3.]]
            else:  # if f0 > 3.5
              return [[ 4.]]
          else:  # if f0 > 4.5
            return [[ 5.]]
      else:  # if f0 > 6.0
        if f0 <= 8.5:
          if f0 <= 7.5:
            return [[ 7.]]
          else:  # if f0 > 7.5
            return [[ 8.]]
        else:  # if f0 > 8.5
          return [[ 9.]]
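
As a rough check of the claim that the output is a valid Python function, the sketch below builds the same text as a string instead of printing it, runs it through exec, and compares the result with the model's own predictions. The name tree_to_source, the toy data, and the float() and .tolist() conversions used when formatting are illustrative assumptions, not part of the answer's code, and the sketch assumes a single-output regressor.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, _tree


def tree_to_source(model, feature_names):
    """Return the source text of a Python function mirroring a fitted tree.

    Hypothetical variant of tree_to_code that collects lines instead of
    printing them.
    """
    tree_ = model.tree_
    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Decision node: emit the test, then both subtrees.
            name = feature_names[tree_.feature[node]]
            # float() keeps the formatted threshold a plain literal.
            threshold = float(tree_.threshold[node])
            lines.append("{}if {} <= {!r}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:  # if {} > {!r}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: .tolist() gives a nested list that is a valid literal.
            lines.append("{}return {}".format(indent, tree_.value[node].tolist()))

    recurse(0, 1)
    return "\n".join(lines)


# Toy data (an assumption): the tree learns to return its input.
X = np.arange(10, dtype=float).reshape(-1, 1)
y = X.ravel()
model = DecisionTreeRegressor(random_state=0).fit(X, y)

source = tree_to_source(model, ["f0"])

# Compile the generated text and pull out the resulting function.
namespace = {}
exec(source, namespace)
generated = namespace["tree"]

# The generated function returns a nested list, hence the [0][0].
agrees = all(
    generated(float(v))[0][0] == model.predict([[float(v)]])[0]
    for v in range(10)
)
```

For a classifier, tree_.value holds one entry per class at each leaf, so the returned list would still need an argmax to turn it into a class index.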
    

    Here are some stumbling blocks that I see in other answers:

    1. Using tree_.threshold == -2 to decide whether a node is a leaf isn't a good idea. What if it's a real decision node with a threshold of -2? Instead, you should look at tree_.feature or tree_.children_*.
    2. The line features = [feature_names[i] for i in tree_.feature] crashes with my version of sklearn, because some values of tree.tree_.feature are -2 (specifically for leaf nodes).
    3. There is no need to have multiple if statements in the recursive function; one is enough.
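
The first two points can be sketched as follows. This is a minimal illustration, not a definitive recipe: the toy regressor, max_depth=2, and the variable names are assumptions made for the example. It shows two leaf tests that do not rely on the threshold value, and the guarded name lookup that avoids indexing feature_names with -2.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, _tree

# Toy model (an assumption for the example): a shallow tree on 0..9.
X = np.arange(10, dtype=float).reshape(-1, 1)
model = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, X.ravel())
tree_ = model.tree_

# Leaf test 1: leaves have no split feature (TREE_UNDEFINED, i.e. -2).
leaf_by_feature = tree_.feature == _tree.TREE_UNDEFINED

# Leaf test 2: leaves have no children (child id is TREE_LEAF, i.e. -1).
leaf_by_children = tree_.children_left == _tree.TREE_LEAF

# Guarded lookup: only index feature_names for real decision nodes.
feature_names = ["f0"]
node_names = [
    feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
    for i in tree_.feature
]
```

Without the guard, feature_names[-2] raises an IndexError for a one-element list, and for longer lists it quietly returns the second-to-last name, which is arguably worse than a crash.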