python recursion tree traversal

How to build a nested tree structure from a list of adjacencies?


Considering that I have:

  • a list of adjacency pairs (child, parent) named A
  • a tree class named Tree that stores its own node key (an integer) and its children (Tree instances)

A = [(61, 66), (50, 61), (68, 61), (33, 61), (57, 66), (72, 66), (37, 68), (71, 33), (6, 50), (11, 37), (5, 37)]

class Tree:
    def __init__(self, node, *children):
        self.node = node
        # Always store a list, so children has the same type with or without arguments
        self.children = list(children)
    
    def __str__(self): 
        return "%s" % (self.node)
    def __repr__(self):
        return "%s" % (self.node)

    def __getitem__(self, k):
        if isinstance(k, int) or isinstance(k, slice): 
            return self.children[k]
        if isinstance(k, str):
            for child in self.children:
                if child.node == k: return child

    def __iter__(self): return self.children.__iter__()

    def __len__(self): return len(self.children)

How can I build a Tree object that nests all the inner trees according to the adjacencies, like the following?

t = Tree(66, 
        Tree(72), 
        Tree(57), 
        Tree(61, 
            Tree(33,
                Tree(71)), 
            Tree(50, 
                Tree(6)), 
            Tree(68, 
                Tree(37, 
                    Tree(11), Tree(5)))))

I was thinking about creating the tree recursively, but I cannot figure out how to traverse and populate it properly. Here is my failed attempt:

from collections import defaultdict

# Create a dictionary: key = parent, values = children
d = defaultdict(list)
for child, parent in A:
    d[parent].append(child)

# Failed attempt
def build_tree(k):    
    if k in d:
        tree = Tree(k, d[k]) #1st issue: should input a Tree() as 2nd parameter
        for child in d[k]:
            build_tree(child) #2nd issue: should populate tree, not iterate recursively over children keys

#I know that the root node is 66.
full_tree = build_tree(66)
        

Solution

  • You mention two issues in this piece of code:

        tree = Tree(k, d[k]) #1st issue: should input a Tree() as 2nd parameter
        for child in d[k]:
            build_tree(child) #2nd issue: should populate tree, not iterate recursively over children keys
    

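    You can solve both by moving the for loop into the second argument as a list comprehension, and unpacking that list with * so that its items become separate arguments. Then make sure your recursive function returns the created tree. If that return stays inside the if k in d block, leaves come back as None; dropping the guard avoids that, because d is a defaultdict and d[k] is simply an empty list for a leaf: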

        return Tree(k, 
            *[build_tree(child) for child in d[k]]
        )
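
Put together with the dictionary from the question, the whole function could look like the following sketch. It assumes a trimmed-down `Tree` class (only `node` and `children`) as a stand-in for the question's class, and it uses `d.get(k, [])` instead of the `if k in d` guard so that leaves are returned as childless trees.

```python
from collections import defaultdict

class Tree:
    # Trimmed-down stand-in for the question's class (assumption: only
    # `node` and `children` matter for building the structure).
    def __init__(self, node, *children):
        self.node = node
        self.children = list(children)

A = [(61, 66), (50, 61), (68, 61), (33, 61), (57, 66), (72, 66),
     (37, 68), (71, 33), (6, 50), (11, 37), (5, 37)]

# key = parent, values = children
d = defaultdict(list)
for child, parent in A:
    d[parent].append(child)

def build_tree(k):
    # A leaf has no entry in d, so it gets no children and the recursion ends
    return Tree(k, *[build_tree(child) for child in d.get(k, [])])

# The root key 66 is passed in by hand, as in the question
full_tree = build_tree(66)
```

Children appear in the order in which their pairs occur in `A`, so the order can differ from the hand-written tree in the question.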
    

    More ideas

    Unrelated to your question, but here are some more ideas you could use.

    • It would be advisable to wrap your code in a function that takes A as an argument, so that the dictionary is local to that function and does not litter the global scope.

    • As this feature is strongly related to the Tree class, it would be nice to define it as a static or class method within the class.

    • When you have the (child, parent) tuples for the tree, then these implicitly define which node is the root, so you could omit passing the literal 66 to your function. That function should be able to find out which is the root by itself. While creating the dictionary it can also collect which nodes have a parent. The root is then the node that is not in that collection.
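
The root detection described in the last point can be tried on its own. The sketch below uses a hypothetical helper named `find_root` (not part of the original code) and additionally assumes that the input should describe exactly one tree, so it raises an error when it finds no root or more than one.

```python
def find_root(pairs):
    """Return the only key that occurs as a parent but never as a child."""
    parents = {parent for _, parent in pairs}
    children = {child for child, _ in pairs}
    roots = parents - children
    if len(roots) != 1:
        # Zero roots points to a cycle, several roots to a forest
        raise ValueError("expected exactly one root, found %d" % len(roots))
    return roots.pop()

A = [(61, 66), (50, 61), (68, 61), (33, 61), (57, 66), (72, 66),
     (37, 68), (71, 33), (6, 50), (11, 37), (5, 37)]

root = find_root(A)
```

For the sample list, 66 is the only key that never shows up in the child position, so it is the one returned.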

    Taking all that together, you would have this:

    from collections import defaultdict
    
    class Tree:
        def __init__(self, node, *children):
            self.node = node
            # Always store a list, so children has the same type with or without arguments
            self.children = list(children)
        
        def __str__(self): 
            return "%s" % (self.node)
        
        def __repr__(self):
            return "%s" % (self.node)
    
        def __getitem__(self, k):
            if isinstance(k, int) or isinstance(k, slice): 
                return self.children[k]
            if isinstance(k, str):
                for child in self.children:
                    if child.node == k:
                        return child
    
        def __iter__(self):
            return self.children.__iter__()
    
        def __len__(self):
            return len(self.children)
    
        @classmethod
        def from_pairs(cls, pairs):
            # Map each parent key to the list of its child keys
            d = defaultdict(list)
            children = set()
            for child, parent in pairs:
                d[parent].append(child)
                # collect nodes that have a parent
                children.add(child)
            
            # Find root: it does not have a parent
            root = next(parent for parent in d if parent not in children)
    
            # Build nested Tree instances recursively from the dictionary
            def subtree(k):
                return cls(k, *[subtree(child) for child in d[k]])
    
            return subtree(root)
    
    # Sample run
    A = [(61, 66), (50, 61), (68, 61), (33, 61), (57, 66), (72, 66), (37, 68), (71, 33), (6, 50), (11, 37), (5, 37)]
    
    tree = Tree.from_pairs(A)
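
To check the result, a short pre-order traversal can list the nodes with indentation. The `render` helper below is hypothetical and not part of the class above. The sketch uses a minimal stand-in `Tree` and a small hand-built tree so that it runs on its own.

```python
class Tree:
    # Minimal stand-in with the same two attributes as the class above
    def __init__(self, node, *children):
        self.node = node
        self.children = list(children)

def render(tree, depth=0):
    """Return one line per node in pre-order, indented by depth."""
    lines = ["  " * depth + str(tree.node)]
    for child in tree.children:
        lines.extend(render(child, depth + 1))
    return lines

# Small hand-built tree for illustration
t = Tree(66, Tree(61, Tree(50, Tree(6))), Tree(57))
print("\n".join(render(t)))
```

Since `render` only reads `node` and `children`, calling `render(tree)` on the object returned by `Tree.from_pairs(A)` should work the same way.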