
Iteratively join tree branches/nodes that have the same leaf values


Let's say I have a dataframe with features x1, x2, x3 and an outcome y:

import numpy as np
import pandas as pd
from functools import reduce

def crossing(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
    # Cartesian product of two frames via a constant dummy key
    return pd.merge(df1.assign(key=1), df2.assign(key=1), on='key').drop(columns='key')

def crossing_many(*args):
    # Cross all given frames, left to right
    return reduce(crossing, args)

df = crossing_many(
    pd.DataFrame({'x1': ['A', 'B', 'C']}),
    pd.DataFrame({'x2': ['X', 'Y', 'Z']}),
    pd.DataFrame({'x3': ['xxx', 'yyy', 'zzz']}),
).assign(y=lambda d: np.random.choice([0, 1], size=len(d)))
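`crossing_many` builds the Cartesian product of the feature columns, so three features with three levels each give 3 × 3 × 3 = 27 rows. A minimal stdlib sketch of the same idea with `itertools.product`, assuming plain tuples instead of a DataFrame (`cross_rows` is a hypothetical helper, and the random outcome is seeded here only so the output is repeatable):

```python
import itertools
import random
from typing import List, Sequence, Tuple


def cross_rows(*levels: Sequence[str]) -> List[Tuple[str, ...]]:
    """Every combination of the given feature levels (Cartesian product)."""
    return list(itertools.product(*levels))


rows = cross_rows(["A", "B", "C"], ["X", "Y", "Z"], ["xxx", "yyy", "zzz"])

# Attach a random 0/1 outcome per row, like the y column; seeded only for repeatability.
rng = random.Random(0)
data = [(*row, rng.choice([0, 1])) for row in rows]

print(len(data))  # 27
print(rows[0])    # ('A', 'X', 'xxx')
```

With pandas 1.2 or later, `pd.merge(df1, df2, how="cross")` produces the same cross join without the dummy `key` column.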

I can plot a tree with the bigtree package quite simply:

from bigtree import dataframe_to_tree, tree_to_dot

features = ['x1', 'x2', 'x3']

def view_pydot(pdot):
    # Render a pydot graph inline in a Jupyter notebook
    from IPython.display import Image, display
    img = Image(pdot.create_png())
    display(img)

tree = (
    df
    .assign(y=lambda d: d['y'].astype('str'))
    .assign(root='Everyone')
    # One path per row: Everyone/<x1>/<x2>/<x3>/<y>
    .assign(path=lambda d: d[['root'] + features + ['y']].agg('/'.join, axis=1))
    .pipe(dataframe_to_tree, path_col='path')
)

view_pydot(tree_to_dot(tree))

I get something like this:

[image: the full tree]
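The `path` column encodes each row as a root-to-leaf path such as `Everyone/A/X/xxx/0`, and `dataframe_to_tree` turns those paths into a tree in which rows with a common prefix share nodes. A rough stdlib sketch of that idea, assuming a nested `dict` (name → children) as the tree; this is not how bigtree stores nodes, `paths_to_nested` is a hypothetical helper, and the three rows are made-up sample data:

```python
from typing import Dict, List, Tuple

Tree = Dict[str, "Tree"]  # hypothetical: node name -> children


def paths_to_nested(paths: List[str], sep: str = "/") -> Tree:
    """Build a nested dict where each key is a node name and each value holds its children."""
    tree: Tree = {}
    for path in paths:
        node = tree
        for name in path.split(sep):
            # setdefault reuses an existing child, so rows with a common prefix share nodes
            node = node.setdefault(name, {})
    return tree


rows: List[Tuple] = [("A", "X", "xxx", 0), ("A", "X", "yyy", 1), ("A", "Y", "xxx", 1)]
# Same layout as the path column: root, x1, x2, x3, y joined with "/"
paths = ["/".join(["Everyone", *map(str, row)]) for row in rows]
tree = paths_to_nested(paths)
print(paths[0])                       # Everyone/A/X/xxx/0
print(sorted(tree["Everyone"]["A"]))  # ['X', 'Y']
```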

The tree is far more complex than it needs to be. I want to iteratively "join" branches/nodes that have the same leaf nodes, on all levels. For example, something like this:

[image: the desired simplified tree]

Basically, I want to create the simplest possible tree so that a person can read it as rules such as IF x1=A AND x2=X THEN 1, i.e. reach the decision through the shortest possible path. It would also make sense to remove nodes that cover all possible values of a feature (for example xxx|yyy|zzz). Thanks!
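Once the tree is simplified, each root-to-leaf path reads as one IF ... THEN rule. A minimal sketch of turning such paths into rules, assuming the feature names are known in path order and that a node whose name covers every level of its feature (e.g. `xxx|yyy|zzz`) imposes no condition and can be skipped; `path_to_rule` and the `levels` mapping are hypothetical:

```python
from typing import Dict, List, Set


def path_to_rule(path: str, features: List[str], levels: Dict[str, Set[str]]) -> str:
    """Turn 'root/v1/.../vn/outcome' into 'IF f1=v1 AND ... THEN outcome'."""
    _root, *values, outcome = path.split("/")
    conditions = []
    for feature, value in zip(features, values):
        options = set(value.split("|"))
        # A node covering every level of its feature adds no condition, so skip it.
        if options == levels[feature]:
            continue
        if len(options) == 1:
            conditions.append(f"{feature}={value}")
        else:
            joined = ", ".join(sorted(options))
            conditions.append(f"{feature} in ({joined})")
    # If every node was skipped, the outcome holds unconditionally.
    condition = " AND ".join(conditions) or "TRUE"
    return f"IF {condition} THEN {outcome}"


features = ["x1", "x2", "x3"]
levels = {
    "x1": {"A", "B", "C"},
    "x2": {"X", "Y", "Z"},
    "x3": {"xxx", "yyy", "zzz"},
}
print(path_to_rule("Everyone/A/X/xxx|yyy|zzz/1", features, levels))
# IF x1=A AND x2=X THEN 1
print(path_to_rule("Everyone/A/X|Z/xxx|zzz/0", features, levels))
# IF x1=A AND x2 in (X, Z) AND x3 in (xxx, zzz) THEN 0
```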


Solution

  • Thanks for using bigtree! This is essentially an algorithm question that requires some tree manipulation. One solution is to combine sibling nodes in a tree if they have the same descendants.

    For example, X/xxx with descendant 0 (X/xxx/0) and its sibling X/zzz with descendant 0 (X/zzz/0) can be combined into X/xxx|zzz with descendant 0.
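    Two siblings can be merged when their subtrees are identical relative to themselves, i.e. after stripping each sibling's own path prefix. A minimal sketch of that comparison for the children of one parent, assuming each child is given as a list of descendant paths relative to the child (`merge_siblings` is a hypothetical helper, not part of bigtree):

```python
from typing import Dict, List, Tuple


def merge_siblings(children: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Merge sibling names whose relative descendant paths are identical.

    `children` maps a child name to its descendant paths relative to that child.
    Names in a group keep their original order and are joined with '|'.
    """
    groups: Dict[Tuple[str, ...], List[str]] = {}
    for name, paths in children.items():
        # Sorting makes the signature independent of descendant order.
        signature = tuple(sorted(paths))
        groups.setdefault(signature, []).append(name)
    return {"|".join(names): list(signature) for signature, names in groups.items()}


# Children of Everyone/A/X in the example data: xxx -> 0, yyy -> 1, zzz -> 0
merged = merge_siblings({"xxx": ["/0"], "yyy": ["/1"], "zzz": ["/0"]})
print(merged)  # {'xxx|zzz': ['/0'], 'yyy': ['/1']}
```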

    The code snippet below is separated into the setup and the solution.

    # Set up
    from bigtree import list_to_tree, tree_to_dot
    
    root = list_to_tree([
        "Everyone/A/X/xxx/0",
        "Everyone/A/X/yyy/1",
        "Everyone/A/X/zzz/0",
        "Everyone/A/Y/xxx/1",
        "Everyone/A/Y/yyy/0",
        "Everyone/A/Y/zzz/1",
        "Everyone/A/Z/xxx/0",
        "Everyone/A/Z/yyy/1",
        "Everyone/A/Z/zzz/0",
        "Everyone/B/X/xxx/1",
        "Everyone/B/X/yyy/0",
        "Everyone/B/X/zzz/0",
        "Everyone/B/Y/xxx/1",
        "Everyone/B/Y/yyy/0",
        "Everyone/B/Y/zzz/0",
        "Everyone/B/Z/xxx/0",
        "Everyone/B/Z/yyy/1",
        "Everyone/B/Z/zzz/1",
        "Everyone/C/X/xxx/1",
        "Everyone/C/X/yyy/0",
        "Everyone/C/X/zzz/0",
        "Everyone/C/Y/xxx/0",
        "Everyone/C/Y/yyy/1",
        "Everyone/C/Y/zzz/0",
        "Everyone/C/Z/xxx/0",
        "Everyone/C/Z/yyy/1",
        "Everyone/C/Z/zzz/1",
    ])
    tree_to_dot(root).write_png("before.png")
    

    Before: [image: before.png]

    # Solution
    from bigtree import Node, postorder_iter, tree_to_dot
    from typing import List
    
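    def get_descendant_paths(_node: Node) -> List[str]:
        # Descendant paths relative to _node, e.g. "/0" for "/Everyone/A/X/xxx/0"
        # (str.removeprefix requires Python 3.9+)
        return sorted([descendant.path_name.removeprefix(_node.path_name) for descendant in _node.descendants])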
    
    # Iterate through the tree in the bottom-up fashion
    for node in postorder_iter(root):
        for node_sib in node.siblings:
            # If nodes have the same descendants
            if get_descendant_paths(node) == get_descendant_paths(node_sib):
                node.name = f"{node.name}|{node_sib.name}"
                node_sib.parent = None
    
    tree_to_dot(root).write_png("after.png")
    

    After: [image: after.png]
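    A bigtree-free sketch of the same bottom-up idea, assuming a nested `dict` (name → children) instead of bigtree `Node`s. It simplifies every child before comparing siblings, so siblings are always compared in their already-merged form; `signature`, `merge_identical_siblings`, `paths_to_nested` and `nested_to_paths` are hypothetical helpers. Here it runs on the `Everyone/A` branch of the setup data:

```python
from typing import Dict, List, Tuple

# Hypothetical representation: each node name maps to a dict of its children.
Tree = Dict[str, "Tree"]


def signature(children: Tree) -> Tuple:
    """Hashable, order-independent description of a subtree (names and shape)."""
    return tuple(sorted((name, signature(sub)) for name, sub in children.items()))


def merge_identical_siblings(children: Tree) -> Tree:
    """Simplify every child first (bottom-up), then merge siblings with equal subtrees."""
    simplified = {name: merge_identical_siblings(sub) for name, sub in children.items()}
    groups: Dict[Tuple, List[str]] = {}
    for name, sub in simplified.items():
        groups.setdefault(signature(sub), []).append(name)
    # Every name in a group has the same subtree, so keep the first one's children.
    return {"|".join(names): simplified[names[0]] for names in groups.values()}


def paths_to_nested(paths: List[str]) -> Tree:
    """Build the nested dict from 'a/b/c' path strings."""
    tree: Tree = {}
    for path in paths:
        node = tree
        for name in path.split("/"):
            node = node.setdefault(name, {})
    return tree


def nested_to_paths(tree: Tree, prefix: str = "") -> List[str]:
    """Flatten the nested dict back into root-to-leaf path strings."""
    paths: List[str] = []
    for name, children in tree.items():
        path = f"{prefix}/{name}" if prefix else name
        paths.extend(nested_to_paths(children, path) if children else [path])
    return paths


# The Everyone/A branch of the setup data
paths = [
    "Everyone/A/X/xxx/0", "Everyone/A/X/yyy/1", "Everyone/A/X/zzz/0",
    "Everyone/A/Y/xxx/1", "Everyone/A/Y/yyy/0", "Everyone/A/Y/zzz/1",
    "Everyone/A/Z/xxx/0", "Everyone/A/Z/yyy/1", "Everyone/A/Z/zzz/0",
]
for p in nested_to_paths(merge_identical_siblings(paths_to_nested(paths))):
    print(p)
# Everyone/A/X|Z/xxx|zzz/0
# Everyone/A/X|Z/yyy/1
# Everyone/A/Y/xxx|zzz/1
# Everyone/A/Y/yyy/0
```

    Like the bigtree loop, this treats all leaves as identical subtrees (they have no descendants), so it assumes every path ends in exactly one outcome leaf. Merged names come out in the order the siblings first appear (X|Z); the bigtree loop may order them differently, depending on which sibling is visited last.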

    Disclaimer: I'm the author of bigtree :)