Let's say I have a DataFrame with features `x1`, `x2`, `x3` and an outcome `y`:
import pandas as pd
def crossing(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
    """Return the Cartesian product (cross join) of two DataFrames.

    Every row of ``df1`` is paired with every row of ``df2``. Rows come out
    ordered by ``df1`` first, then ``df2``, with a fresh RangeIndex.

    ``how="cross"`` (pandas >= 1.2) replaces the old trick of merging on a
    temporary ``key=1`` column. That trick silently overwrote, and then
    dropped, any genuine column named ``key`` in either input.
    """
    return pd.merge(df1, df2, how="cross")
def crossing_many(*args):
    """Cross-join any number of DataFrames, folding from left to right.

    ``crossing_many(a, b, c)`` equals ``crossing(crossing(a, b), c)``. A single
    argument is returned unchanged. Calling with no arguments raises
    ``TypeError`` because there is nothing to fold.
    """
    import functools

    return functools.reduce(crossing, args)
# The outcome column is drawn with numpy, which the snippet previously used
# without importing (NameError on `np`).
import numpy as np

# Every combination of the three categorical features (3 * 3 * 3 = 27 rows),
# each with a random binary outcome `y`.
df = crossing_many(
    pd.DataFrame({'x1': ['A', 'B', 'C']}),
    pd.DataFrame({'x2': ['X', 'Y', 'Z']}),
    pd.DataFrame({'x3': ['xxx', 'yyy', 'zzz']}),
).assign(y=lambda d: np.random.choice([0, 1], size=len(d)))
I can plot a tree quite simply with the `bigtree` package:
from bigtree import dataframe_to_tree
def view_pydot(pdot):
    """Render a pydot graph as PNG and display it inline in IPython/Jupyter.

    ``pdot`` is expected to be a pydot graph, such as the one returned by
    bigtree's ``tree_to_dot``; its ``create_png()`` bytes are wrapped in an
    IPython ``Image``.
    """
    from IPython.display import Image, display

    display(Image(pdot.create_png()))
# tree_to_dot was used below without being imported.
from bigtree import tree_to_dot

# Feature columns, in DataFrame order: everything except the outcome `y`.
# (`features` was previously referenced without ever being defined.)
features = [col for col in df.columns if col != 'y']

# Build one "/"-separated path per row: Everyone/<x1>/<x2>/<x3>/<y>.
# NOTE(review): feature values must not contain "/", or they would be split
# into extra tree levels.
tree = (
    df
    .assign(y=lambda d: d['y'].astype('str'))
    .assign(root='Everyone')
    .assign(path=lambda d: d[['root'] + features + ['y']].agg('/'.join, axis=1))
    .pipe(dataframe_to_tree, path_col='path')
)
view_pydot(tree_to_dot(tree))
The tree is far more complex than it needs to be. I want to iteratively "join" branches/nodes that lead to the same leaf nodes, on all levels. For example, something like this:
Basically, I want to create as simple a tree as possible, so that a person can read it as a set of rules such as IF x1=A AND x2=X THEN 1 (i.e., reach the decision through the shortest possible path). It would also make sense to remove nodes that cover all possible values of a feature (for example, `xxx|yyy|zzz`). Thanks!
Thanks for using `bigtree`! This feels like an algorithm question that requires some tree manipulation. One solution is to combine sibling nodes in the tree if they have the same descendants.
For example, `X/xxx` with descendant `0` (`X/xxx/0`) and its sibling `X/zzz` with descendant `0` (`X/zzz/0`) can be combined into `X/xxx|zzz` with descendant `0`.
The code snippet below is separated into the setup and the solution.
# Set up
from bigtree import list_to_tree, tree_to_dot
# One root-to-leaf path per feature combination; the final segment is the outcome.
leaf_paths = [
    # x1 = A
    "Everyone/A/X/xxx/0",
    "Everyone/A/X/yyy/1",
    "Everyone/A/X/zzz/0",
    "Everyone/A/Y/xxx/1",
    "Everyone/A/Y/yyy/0",
    "Everyone/A/Y/zzz/1",
    "Everyone/A/Z/xxx/0",
    "Everyone/A/Z/yyy/1",
    "Everyone/A/Z/zzz/0",
    # x1 = B
    "Everyone/B/X/xxx/1",
    "Everyone/B/X/yyy/0",
    "Everyone/B/X/zzz/0",
    "Everyone/B/Y/xxx/1",
    "Everyone/B/Y/yyy/0",
    "Everyone/B/Y/zzz/0",
    "Everyone/B/Z/xxx/0",
    "Everyone/B/Z/yyy/1",
    "Everyone/B/Z/zzz/1",
    # x1 = C
    "Everyone/C/X/xxx/1",
    "Everyone/C/X/yyy/0",
    "Everyone/C/X/zzz/0",
    "Everyone/C/Y/xxx/0",
    "Everyone/C/Y/yyy/1",
    "Everyone/C/Y/zzz/0",
    "Everyone/C/Z/xxx/0",
    "Everyone/C/Z/yyy/1",
    "Everyone/C/Z/zzz/1",
]
root = list_to_tree(leaf_paths)

# Snapshot the unsimplified tree for comparison with the result.
before_graph = tree_to_dot(root)
before_graph.write_png("before.png")
# Solution
from bigtree import Node, postorder_iter, tree_to_dot
from typing import List
def get_descendant_paths(_node: Node) -> List[str]:
    """Return the sorted paths of all descendants of ``_node``, relative to it.

    Each descendant's ``path_name`` has ``_node``'s own ``path_name`` stripped
    from the front. The result describes the subtree's shape independently of
    where it sits in the tree, so two subtrees can be compared for equality.
    A node with no descendants yields an empty list.
    """
    prefix = _node.path_name
    relative_paths = [
        child.path_name.removeprefix(prefix) for child in _node.descendants
    ]
    relative_paths.sort()
    return relative_paths
def _merge_equivalent_children(parent: Node) -> None:
    """Merge the non-leaf children of ``parent`` that have identical descendants.

    Children are scanned in order. A child whose relative descendant paths
    equal those of an earlier surviving child is folded into that survivor:
    the names are joined with ``|`` (e.g. ``xxx|zzz``) and the child is
    detached from the tree.

    Leaf children are never merged. Every leaf has an empty descendant list,
    so comparing two leaves would always report a match and collapse distinct
    outcomes such as ``0`` and ``1`` into ``0|1``.
    """
    survivors = []  # (child, descendant signature) pairs kept under `parent`
    for child in list(parent.children):  # snapshot: children are detached below
        if child.is_leaf:
            continue
        signature = get_descendant_paths(child)
        for survivor, survivor_signature in survivors:
            if survivor_signature == signature:
                # Renaming the survivor leaves its *relative* signature unchanged,
                # so the cached survivor_signature stays valid.
                survivor.name = f"{survivor.name}|{child.name}"
                child.parent = None
                break
        else:
            survivors.append((child, signature))


# Iterate through the tree bottom-up (post-order). A parent is visited only
# after all of its children's subtrees have been simplified, so siblings are
# always compared on their final, merged shape rather than a mix of merged
# and not-yet-merged subtrees.
for node in postorder_iter(root):
    _merge_equivalent_children(node)

tree_to_dot(root).write_png("after.png")
Disclaimer: I'm the author of `bigtree` :)