python-2.7, dictionary, decision-tree, pruning

Python Dictionary - Consolidating Leaf Nodes below a threshold


Below is a simplified example of a decision tree (stored as a nested dict) that I trained in Python:

tree= {'Age': {'> 55': 0.4, '< 18': {'Income': {'high': 0, 'low': 0.2}}, 
               '18-35': 0.25, '36-55': {'Marital_Status': {'single': {'Income': 
               {'high': 0, 'low': 0.1}}, 'married': 0.05}}}}

The numbers in the leaf nodes (boxes) represent the probability of a class label (e.g. TRUE) appearing in that node. Visually, the tree looks like this:

[Image: plot of the tree above]

I am trying to write a generic post-pruning algorithm that consolidates nodes whose values are less than 0.3 into their parent nodes. With a 0.3 threshold, the resulting tree would look like this when plotted:

[Image: plot of the pruned tree]

In the second figure, please note that the Income node at Age < 18 has now been consolidated into the root node Age. The Marital_Status node at Age = 36-55 has also been consolidated into Age, since the sum of all its leaf nodes (across multiple levels) is less than 0.3.
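
As a quick check of the sums involved, here is a small sketch. `leaf_sum` is a hypothetical helper introduced only for illustration; it is not part of the pruning code and simply adds up every leaf below a node of the tree shown earlier.

```python
# Hypothetical helper for illustration: total of all leaf values under a node.
def leaf_sum(node):
    if isinstance(node, dict):
        # Interim node: add up the leaf sums of all children, at any depth.
        return sum(leaf_sum(child) for child in node.values())
    return node  # a leaf is already a number

tree = {'Age': {'> 55': 0.4, '< 18': {'Income': {'high': 0, 'low': 0.2}},
                '18-35': 0.25, '36-55': {'Marital_Status': {'single': {'Income':
                {'high': 0, 'low': 0.1}}, 'married': 0.05}}}}

age = tree['Age']
under_18 = leaf_sum(age['< 18'])    # 0 + 0.2
age_36_55 = leaf_sum(age['36-55'])  # 0 + 0.1 + 0.05, up to floating-point rounding

# Both subtrees fall below the 0.3 threshold, so both would be consolidated.
print(under_18 < 0.3, age_36_55 < 0.3)
```

Because the leaves are floats, the second sum may not be exactly 0.15, so comparisons against the threshold are safer than equality checks.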

This is the incomplete pseudo-code I came up with (so far):

def post_prune(dictionary, threshold):
    for k in dictionary.keys():
        if isinstance(dictionary[k], dict):  # interim node
            post_prune(dictionary[k], threshold)
        else:  # leaf node
            if dictionary[k] > threshold:
                pass
            else:
                to_do = 'delete this node'  # placeholder: consolidation not written yet

I am posting this question because I feel this must have been solved many times before.

Thank you.

P.S: I am not going to use the end result for classification, so pruning this way (cosmetically) works.


Solution

  • You can try something like this:

    def simplify(tree, threshold):
        # simplify tree bottom-up
        for key, child in tree.items():
            if isinstance(child, dict):
                tree[key] = simplify(child, threshold)
        # all child nodes are leaves and none exceeds the threshold -> return max
        if all(isinstance(child, str) and float(child) <= threshold 
               for child in tree.values()):
            return max(tree.values(), key=float)
        # else return tree itself
        return tree
    

    Example:

    >>> tree= {'Age': {'> 55': '0.4', '18-35': '0', \
                       '< 18': {'Income': {'high': '0', 'low': '0.2'}}, \
                       '36-55': {'Marital_Status': {'single': {'Income': {'high': '0', 'low': '0.1'}}, \
                                                    'married': '0.3'}}}}
    >>> simplify(tree, 0.2)
    {'Age': {'> 55': '0.4', '< 18': '0.2', '18-35': '0', 
             '36-55': {'Marital_Status': {'single': '0.1', 'married': '0.3'}}}}
    

    Update: It seems I misunderstood your question. You want the simplified tree to hold the sum of the leaves whenever that sum does not exceed the threshold. Your suggested edit was slightly off. Try this:

    def simplify(tree, threshold):
        # simplify tree bottom-up
        for key, child in tree.items():
            if isinstance(child, dict):
                tree[key] = simplify(child, threshold)
        # all child nodes are leaves and their sum is at most threshold -> return sum
        if all(isinstance(child, str) for child in tree.values()) \
               and sum(map(float, tree.values())) <= threshold:
            return str(sum(map(float, tree.values())))
        # else return tree itself
        return tree
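
The tree in the question stores numeric leaves, while the version above expects strings. Below is a minimal sketch of the same bottom-up idea for numeric leaves. The name `simplify_numeric` is made up here, and the sketch assumes that every non-dict value is a number. It keeps the `<=` comparison from the code above (the question says "less than", so switch to `<` if a sum exactly equal to the threshold should be kept). It builds a new dict instead of modifying the input.

```python
def simplify_numeric(tree, threshold):
    """Return a pruned copy of `tree`; the input dict is left untouched."""
    # Recurse first so that children are already consolidated (bottom-up).
    pruned = {}
    for key, child in tree.items():
        if isinstance(child, dict):
            pruned[key] = simplify_numeric(child, threshold)
        else:
            pruned[key] = child
    # Collapse this node only if every child is now a numeric leaf
    # and the leaves sum to at most the threshold.
    leaves = list(pruned.values())
    if all(not isinstance(v, dict) for v in leaves) and sum(leaves) <= threshold:
        return sum(leaves)
    return pruned

tree = {'Age': {'> 55': 0.4, '< 18': {'Income': {'high': 0, 'low': 0.2}},
                '18-35': 0.25, '36-55': {'Marital_Status': {'single': {'Income':
                {'high': 0, 'low': 0.1}}, 'married': 0.05}}}}

result = simplify_numeric(tree, 0.3)
# '< 18' and '36-55' collapse to single numbers; 'Age' stays a dict
# because its leaves sum to roughly 1.0, well above the threshold.
print(result)
```

With a 0.3 threshold, the `'< 18'` and `'36-55'` branches should each reduce to a single number under `'Age'`, which matches the pruned tree described in the question. The consolidated values are float sums, so they can carry small rounding errors.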