Tags: python, algorithm, recursion, tree, behavior-tree

Get maximum tree depth for each node


Suppose we have a tree in which each node can have multiple children, and those children can have children of their own, and so on.

Take this tree as an example:

- Node 1
    - Node 1.1
    - Node 1.2
        - Node 1.2.1
            - Node 1.2.1.1
            - Node 1.2.1.2
        - Node 1.2.2
    - Node 1.3
        - Node 1.3.1

Node 1 has depth = 0 (root).

Nodes 1.1, 1.2, and 1.3 have depth = 1, and so on.

For each node I want to calculate the maximum depth its subtree reaches, measured from that node. For example, Node 1's max depth is 3 (the tree goes as deep as Node 1.2.1.1), while Node 1.3's max depth is 1 (its subtree reaches only as deep as Node 1.3.1).

Now, what I could do is write a function that takes a subtree, walks down to its deepest node, and returns the depth. But that would mean calling the function for each and every node, which seems very inefficient to me.
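
For reference, that per-node approach might look like the sketch below. The `Node` class and the function names `subtree_height` and `annotate_naively` are hypothetical stand-ins, not taken from the real code. Each node's height is recomputed from scratch, so every node is revisited once per ancestor, which can approach quadratic time on a deep, chain-like tree.

```python
class Node:
    """Hypothetical stand-in for the real tree node type."""

    def __init__(self, name, children=None):
        self.name = name
        self.children = children or []
        self.max_depth = None


def subtree_height(node):
    """Number of edges from `node` down to its deepest descendant."""
    if not node.children:
        return 0
    return 1 + max(subtree_height(child) for child in node.children)


def annotate_naively(node):
    """Recompute the height from scratch at every node (repeated work)."""
    node.max_depth = subtree_height(node)
    for child in node.children:
        annotate_naively(child)


# The example tree from above.
tree = Node("1", [
    Node("1.1"),
    Node("1.2", [
        Node("1.2.1", [Node("1.2.1.1"), Node("1.2.1.2")]),
        Node("1.2.2"),
    ]),
    Node("1.3", [Node("1.3.1")]),
])
annotate_naively(tree)
print(tree.max_depth)              # 3
print(tree.children[2].max_depth)  # 1 (Node 1.3)
```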

I want to create the tree and compute the max depths in one go.

I'm keeping the code very simple, since my real function does a lot of other things (such as generating new children, because I'm building the tree from scratch), and I left those parts out. Basically, I'm going through the tree like this:

def create_behavior_tree(depth, max_depth, behavior_tree):
  for child in behavior_tree.children:
    if depth > max_depth:
      max_depth = depth
    if len(child) > 0: # Expanding node
      max_depth = create_behavior_tree(depth + 1, max_depth, child)
      child.max_depth = max_depth  # problem: stores the values in "reverse"
    else: # Single node without children
      child.max_depth = depth
  return max_depth  # hand the running maximum back to the caller


create_behavior_tree(1, 0, Tree)

However, when I do that, I cannot get the final max_depth value to the outer nodes; it is only available inside the innermost call (since this is recursion). So this would compute: Node 1 max depth = 0, Node 1.2 max depth = 1, Node 1.2.1 max depth = 2, etc. It's effectively in reverse.

So, perhaps I need to use global variables here?

EDIT – A more detailed version of my function

def create_behavior_tree(depth, behavior_tree, children, max_tree_depth, node_count):
    if depth <= max_tree_depth:
        for child in children:
            # add behavior node
            if type(child) == Behaviour:
                behavior_tree.add_children([child])
                node_count += 1  # counting total nodes in tree
            # add composite node
            if type(child) == Composite:
                # replace by behavior node (reached max tree depth)
                if depth == max_tree_depth:
                    node = create_behaviour_node()
                    behavior_tree.add_children([node])
                    node_count += 1
                else:
                    behavior_tree.add_children([child])
                    node_count += 1
                    # further expand behavior tree below this composite
                    new_children = create_random_children(size=3)
                    _, node_count = create_behavior_tree(depth + 1, child, new_children, max_tree_depth, node_count)
    return behavior_tree, node_count

random_children = create_random_children(size=3)  # Create 3 random children
root = py_trees.composites.Selector("Selector")

create_behavior_tree(1, root, random_children, 5, 0)

Solution

  • This is the same as any other one-pass O(n) tree-height algorithm, with one addition: because every recursive call already computes the height of its own subtree, you can store that value on the node as you work back up the call stack:

    def set_all_depths(tree):
        if not tree:
            return 0
    
        tree.max_depth = (
            max(map(set_all_depths, tree.children)) 
            if tree.children else 0
        )
        return tree.max_depth + 1
    
    def print_tree(tree, depth=0):
        if tree:
            print("    " * depth + 
                  f"{tree.val} [max_depth: {tree.max_depth}]")
    
            for child in tree.children:
                print_tree(child, depth + 1)
    
    
    class TreeNode:
        def __init__(self, val, children=None, max_depth=None):
            self.val = val
            self.children = children or []
            self.max_depth = max_depth
    
    
    if __name__ == "__main__":
        tree = TreeNode(
            "1",
            [
                TreeNode("1.1"),
                TreeNode(
                    "1.2",
                    [
                        TreeNode(
                            "1.2.1",
                            [
                                TreeNode("1.2.1.1"),
                                TreeNode("1.2.1.2"),
                            ],
                        ),
                        TreeNode("1.2.2"),
                    ],
                ),
                TreeNode(
                    "1.3",
                    [
                        TreeNode("1.3.1"),
                    ],
                ),
            ],
        )
        set_all_depths(tree)
        print_tree(tree)
    

    Output:

    1 [max_depth: 3]
        1.1 [max_depth: 0]
        1.2 [max_depth: 2]
            1.2.1 [max_depth: 1]
                1.2.1.1 [max_depth: 0]
                1.2.1.2 [max_depth: 0]
            1.2.2 [max_depth: 0]
        1.3 [max_depth: 1]
            1.3.1 [max_depth: 0]
    

    Based on the follow-up comments: if you want to build the tree and assign max depths in one go, one way is to take the largest max depth among a node's children and assign one more than that to the node on the way back up the call stack. A node without children gets 0.

    Here's a library-agnostic proof of concept, since I don't have the classes you have in your new example:

    import random


    def make_random_tree(max_depth=4):
        # Depth budget exhausted: return a leaf.
        if max_depth <= 0:
            return TreeNode(Id.make(), max_depth=0)

        node = TreeNode(Id.make())
        node.children = [
            make_random_tree(max_depth - 1)
            for _ in range(random.randint(0, 4))
        ]
        # A node with no children is a leaf (0); otherwise it sits one
        # level above its deepest child.
        node.max_depth = (
            1 + max(x.max_depth for x in node.children)
            if node.children else 0
        )
        return node

    def print_tree(tree, depth=0):
        if tree:
            print("    " * depth + 
                  f"{tree.val} [max_depth: {tree.max_depth}]")

            for child in tree.children:
                print_tree(child, depth + 1)


    class TreeNode:
        def __init__(self, val, children=None, max_depth=None):
            self.val = val
            self.children = children or []
            self.max_depth = max_depth


    class Id:
        identifier = 0

        @staticmethod
        def make():
            Id.identifier += 1
            return Id.identifier


    if __name__ == "__main__":
        random.seed(1)
        tree = make_random_tree()
        print_tree(tree)


    Output:

    1 [max_depth: 4]
        2 [max_depth: 3]
            3 [max_depth: 0]
            4 [max_depth: 2]
                5 [max_depth: 0]
                6 [max_depth: 1]
                    7 [max_depth: 0]
                    8 [max_depth: 0]
                    9 [max_depth: 0]
            10 [max_depth: 2]
                11 [max_depth: 1]
                    12 [max_depth: 0]
                    13 [max_depth: 0]
                    14 [max_depth: 0]
                15 [max_depth: 1]
                    16 [max_depth: 0]
                    17 [max_depth: 0]
                    18 [max_depth: 0]
                19 [max_depth: 1]
                    20 [max_depth: 0]
            21 [max_depth: 0]
    

    That said, assigning depths in a second pass is still O(n), so doing it all in one pass might be premature optimization unless you have truly large trees or performance-critical code.
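
If you do go with a separate pass and the tree might be deeper than Python's recursion limit (1000 frames by default in CPython), the same bottom-up assignment can be done with an explicit stack. This is only a sketch under that assumption: `set_all_depths_iterative` is a name made up here, and `TreeNode` is repeated so the snippet runs on its own.

```python
class TreeNode:
    def __init__(self, val, children=None, max_depth=None):
        self.val = val
        self.children = children or []
        self.max_depth = max_depth


def set_all_depths_iterative(root):
    """Assign max_depth bottom-up using an explicit stack, no recursion."""
    if root is None:
        return
    # Collect nodes so that every parent appears before its children.
    order = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        stack.extend(node.children)
    # Walking the list backwards visits children before their parents,
    # so each child's max_depth is ready when the parent needs it.
    for node in reversed(order):
        node.max_depth = (
            1 + max(child.max_depth for child in node.children)
            if node.children else 0
        )


# A chain of 5000 nodes, far deeper than the default recursion limit.
chain = TreeNode(0)
tail = chain
for i in range(1, 5000):
    nxt = TreeNode(i)
    tail.children.append(nxt)
    tail = nxt

set_all_depths_iterative(chain)
print(chain.max_depth)  # 4999
```

This keeps the O(n) cost of the recursive version: each node is pushed, popped, and assigned exactly once.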