Below is a simplified example of a decision tree (a dict) that I trained in Python:
```python
tree = {'Age': {'> 55': 0.4,
                '< 18': {'Income': {'high': 0, 'low': 0.2}},
                '18-35': 0.25,
                '36-55': {'Marital_Status': {'single': {'Income': {'high': 0, 'low': 0.1}},
                                             'married': 0.05}}}}
```
The numbers in the leaf nodes (boxes) represent the probability of a class label (e.g. TRUE) appearing in that node. Visually, the tree looks like this:
[Figure: the original, unpruned tree]
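Since the figures are not reproduced here, a small text renderer can stand in for them. This is a minimal sketch, assuming inner nodes are dicts and leaves are numbers; `format_tree` is a hypothetical helper name, not part of any library.

```python
# Minimal sketch: render the nested-dict tree as indented text.
# Assumes inner nodes are dicts and leaves are numbers (hypothetical helper).

def format_tree(node, indent=0):
    """Return one text line per node; leaves show their value after a colon."""
    lines = []
    pad = "    " * indent
    for key, child in node.items():
        if isinstance(child, dict):
            lines.append(f"{pad}{key}")                    # inner node: name only
            lines.extend(format_tree(child, indent + 1))   # then its children, indented
        else:
            lines.append(f"{pad}{key}: {child}")           # leaf: name and value
    return lines


tree = {'Age': {'> 55': 0.4,
                '< 18': {'Income': {'high': 0, 'low': 0.2}},
                '18-35': 0.25,
                '36-55': {'Marital_Status': {'single': {'Income': {'high': 0, 'low': 0.1}},
                                             'married': 0.05}}}}

print("\n".join(format_tree(tree)))
```

Each level of nesting adds four spaces, so the Age = 36-55 branch shows Marital_Status, then single, then Income, with the high/low leaves at the deepest indentation.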
I am trying to code a generic post-pruning algorithm that consolidates nodes whose values are less than 0.3 into their parent nodes. With a 0.3 threshold, the resulting tree would look like this when plotted:
[Figure: the pruned tree with a 0.3 threshold]
In the second figure, note that the Income node under Age < 18 has been consolidated into the root node Age. Likewise, the Marital_Status node under Age = 36-55 has been consolidated into Age, since the sum of all its leaf nodes (across multiple levels) is less than 0.3.
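The consolidation rule can be checked by summing every leaf under each branch, at any depth. A minimal sketch, assuming numeric leaves; `subtree_sum` is a hypothetical helper:

```python
# Minimal sketch: sum every leaf below a node, at any depth.
# Assumes inner nodes are dicts and leaves are numbers (hypothetical helper).

def subtree_sum(node):
    """Return the total of all numeric leaves under node."""
    if isinstance(node, dict):
        return sum(subtree_sum(child) for child in node.values())
    return node  # a leaf is its own sum


tree = {'Age': {'> 55': 0.4,
                '< 18': {'Income': {'high': 0, 'low': 0.2}},
                '18-35': 0.25,
                '36-55': {'Marital_Status': {'single': {'Income': {'high': 0, 'low': 0.1}},
                                             'married': 0.05}}}}

threshold = 0.3
for branch, child in tree['Age'].items():
    if isinstance(child, dict):  # only inner nodes can be consolidated
        total = subtree_sum(child)
        verdict = 'consolidate' if total < threshold else 'keep'
        print(f"Age {branch}: sum = {total:.2f} -> {verdict}")
```

This prints a sum of 0.20 for Age < 18 (0 + 0.2) and 0.15 for Age = 36-55 (0 + 0.1 + 0.05), both below 0.3, which matches the two consolidations described above.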
This is the incomplete pseudo-code I came up with (so far):
```python
def post_prune(dictionary, threshold):
    for k in dictionary:
        if isinstance(dictionary[k], dict):  # interim node
            post_prune(dictionary[k], threshold)
        else:  # leaf node
            if dictionary[k] > threshold:
                pass
            else:
                to_do = 'delete this node'  # TODO: merge this leaf into its parent
```
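For reference, one way to complete this pseudo-code is to recurse first and decide afterwards. This is a sketch under stated assumptions, not the only possible approach: leaves are numbers, inner nodes are dicts, and a node is collapsed into its total when every child is a leaf and their sum is strictly below the threshold. Unlike the pseudo-code, it returns a new tree instead of mutating the input.

```python
# Sketch of a completed post_prune (bottom-up).
# Assumptions: numeric leaves, dict inner nodes, strict "< threshold" rule.

def post_prune(node, threshold):
    # Prune the children first, so collapses propagate upward.
    pruned = {key: post_prune(child, threshold) if isinstance(child, dict) else child
              for key, child in node.items()}
    # If every child is now a leaf and their total is below the threshold,
    # replace the whole node by that total.
    if all(not isinstance(child, dict) for child in pruned.values()):
        total = sum(pruned.values())
        if total < threshold:
            return total
    return pruned


tree = {'Age': {'> 55': 0.4,
                '< 18': {'Income': {'high': 0, 'low': 0.2}},
                '18-35': 0.25,
                '36-55': {'Marital_Status': {'single': {'Income': {'high': 0, 'low': 0.1}},
                                             'married': 0.05}}}}

print(post_prune(tree, 0.3))
```

With a 0.3 threshold the Age < 18 branch collapses to 0.2 and the Age = 36-55 branch to roughly 0.15 (floating-point addition may print it with a long tail of digits). Because probabilities are non-negative, collapsing bottom-up this way gives the same result as collapsing any subtree whose total leaf sum is below the threshold.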
I wanted to post the question since I feel this must have been solved many times before.
Thank you.
P.S.: I am not going to use the end result for classification, so pruning this way (cosmetically) is fine.
You can try something like this:
```python
def simplify(tree, threshold):
    # simplify tree bottom-up; leaf values are strings, as in the example below
    for key, child in tree.items():
        if isinstance(child, dict):
            tree[key] = simplify(child, threshold)
    # all child nodes are leaves and at most threshold -> return the max
    if all(isinstance(child, str) and float(child) <= threshold
           for child in tree.values()):
        return max(tree.values(), key=float)
    # else return tree itself
    return tree
```
Example:
```python
>>> tree = {'Age': {'> 55': '0.4', '18-35': '0',
...                 '< 18': {'Income': {'high': '0', 'low': '0.2'}},
...                 '36-55': {'Marital_Status': {'single': {'Income': {'high': '0', 'low': '0.1'}},
...                                              'married': '0.3'}}}}
>>> simplify(tree, 0.2)
{'Age': {'> 55': '0.4', '18-35': '0', '< 18': '0.2', '36-55': {'Marital_Status': {'single': '0.1', 'married': '0.3'}}}}
```
Update: It seems I misunderstood your question: you want the simplified tree to hold the sum of the leaves whenever that sum is smaller than the threshold. Your suggested edit was slightly off; try this:
```python
def simplify(tree, threshold):
    # simplify tree bottom-up
    for key, child in tree.items():
        if isinstance(child, dict):
            tree[key] = simplify(child, threshold)
    # all child nodes are (string) leaves and their sum is at most threshold -> return the sum
    if all(isinstance(child, str) for child in tree.values()):
        total = sum(map(float, tree.values()))
        if total <= threshold:
            return str(total)
    # else return tree itself
    return tree
```
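As a usage sketch: the updated simplify() only treats str values as leaves, so the float tree from the question needs its leaves converted first. Below, the function is copied from the update so the block runs on its own, and to_str_leaves is a hypothetical helper that is not part of the answer.

```python
# Usage sketch: apply the updated simplify() to the question's float tree.
# to_str_leaves is a hypothetical helper; simplify is copied from the update.

def simplify(tree, threshold):
    # simplify tree bottom-up
    for key, child in tree.items():
        if isinstance(child, dict):
            tree[key] = simplify(child, threshold)
    # all child nodes are (string) leaves and their sum is at most threshold -> return the sum
    if all(isinstance(child, str) for child in tree.values()):
        total = sum(map(float, tree.values()))
        if total <= threshold:
            return str(total)
    # else return tree itself
    return tree


def to_str_leaves(node):
    """Return a copy of node with every non-dict leaf converted to str."""
    return {key: to_str_leaves(child) if isinstance(child, dict) else str(child)
            for key, child in node.items()}


tree = {'Age': {'> 55': 0.4,
                '< 18': {'Income': {'high': 0, 'low': 0.2}},
                '18-35': 0.25,
                '36-55': {'Marital_Status': {'single': {'Income': {'high': 0, 'low': 0.1}},
                                             'married': 0.05}}}}

pruned = simplify(to_str_leaves(tree), 0.3)
print(pruned)
```

With a 0.3 threshold this should reproduce the second figure: Age keeps four leaves, with '< 18' collapsed to '0.2' and '36-55' collapsed to about 0.15. The latter string may carry floating-point noise (for example '0.15000000000000002'), so rounding the total before calling str() would be one way to tidy the output.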