
Edgelist from a binary tree


I'd like to make a pretty plot of my binary tree.

Here's my custom BinaryTree class:

class BinaryTree:

    def __init__(self, data):
        self.data = data
        self.right = None
        self.left = None

To plot this tree I'll use the networkx library, so I need to convert my tree to a networkx graph and then draw it with Graphviz. The problem is the edge list: to build the networkx graph, I need its edges.

For example, take the binary tree in the following figure: node 0 is the root with children 1 and 2, and node 2 has children 3 and 4.

[Figure: example binary tree]

I need to retrieve its edge list, which would be something like this:

[(0,1),(0,2),(2,3),(2,4)]

Note that in my case the nodes don't have IDs. So how can I do this? I believe it could be a recursive function that takes the depth into account, but I'm having some difficulties, so a little help would be appreciated. ;)
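One way to take depth into account is to hand out IDs level by level with a breadth-first (level-order) traversal, so nodes nearer the root always get smaller IDs. This is a minimal sketch, assuming the question's `BinaryTree` class; `edgelist_bfs` is a hypothetical helper name, not taken from the thread.

```python
from collections import deque


class BinaryTree:
    """Same shape as the question's class: data plus left/right children."""

    def __init__(self, data):
        self.data = data
        self.right = None
        self.left = None


def edgelist_bfs(root):
    """Return (parent_id, child_id) edges, numbering nodes in level order.

    IDs are assigned when a node is first reached, so a node at a smaller
    depth always gets a smaller ID than any node deeper in the tree.
    """
    if root is None:
        return []
    edges = []
    next_id = 1                 # the root takes ID 0
    queue = deque([(root, 0)])  # pairs of (node, its assigned ID)
    while queue:
        node, node_id = queue.popleft()
        for child in (node.left, node.right):
            if child is not None:
                edges.append((node_id, next_id))
                queue.append((child, next_id))
                next_id += 1
    return edges


# Build the example tree from the question.
nodes = [BinaryTree(i) for i in range(5)]
nodes[0].left, nodes[0].right = nodes[1], nodes[2]
nodes[2].left, nodes[2].right = nodes[3], nodes[4]
print(edgelist_bfs(nodes[0]))  # [(0, 1), (0, 2), (2, 3), (2, 4)]
```

Using an explicit queue instead of recursion also avoids Python's recursion limit on very deep, skewed trees.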

EDIT

Thanks for the answers, but I found a solution myself that works well. :P Here it is:

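def edgelist(node, output, node_id=0):
    # Heap-style numbering: the children of node i get IDs 2*i+1 and 2*i+2,
    # so IDs are unique but only consecutive when the tree is complete.
    if node is None:
        return output

    if node.left is not None:
        output.append((node_id, node_id * 2 + 1))

    if node.right is not None:
        output.append((node_id, node_id * 2 + 2))

    edgelist(node.left, output, node_id * 2 + 1)
    edgelist(node.right, output, node_id * 2 + 2)

    return output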
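Because this solution numbers nodes heap-style (the children of node i get IDs 2*i+1 and 2*i+2), the IDs are unique but not consecutive unless the tree is complete: on the example tree, node 2's children come out as 5 and 6 rather than 3 and 4. That is harmless for drawing, but if the compact numbering from the question is wanted, a relabeling pass is one option. A minimal sketch; the block repeats the function so it runs on its own, and `compact_ids` is a hypothetical helper name.

```python
class BinaryTree:
    def __init__(self, data):
        self.data = data
        self.right = None
        self.left = None


def edgelist(node, output, node_id=0):
    """Heap-style numbering: node i's children get IDs 2*i+1 and 2*i+2."""
    if node is None:
        return output
    if node.left is not None:
        output.append((node_id, node_id * 2 + 1))
    if node.right is not None:
        output.append((node_id, node_id * 2 + 2))
    edgelist(node.left, output, node_id * 2 + 1)
    edgelist(node.right, output, node_id * 2 + 2)
    return output


def compact_ids(edges):
    """Relabel IDs to 0, 1, 2, ... in order of first appearance in the edge list."""
    mapping = {}
    for u, v in edges:
        for x in (u, v):
            if x not in mapping:
                mapping[x] = len(mapping)
    return [(mapping[u], mapping[v]) for u, v in edges]


nodes = [BinaryTree(i) for i in range(5)]
nodes[0].left, nodes[0].right = nodes[1], nodes[2]
nodes[2].left, nodes[2].right = nodes[3], nodes[4]

heap_edges = edgelist(nodes[0], [])
print(heap_edges)               # [(0, 1), (0, 2), (2, 5), (2, 6)]
print(compact_ids(heap_edges))  # [(0, 1), (0, 2), (2, 3), (2, 4)]
```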

Solution

  • Here is one way you could modify the BinaryTree class to dump an edgelist:

    import itertools as IT

    import matplotlib.pyplot as plt
    import networkx as nx

    class BinaryTree:
        def __init__(self, data):
            self.data = data
            self.right = None
            self.left = None
            self.name = None  # integer ID, assigned lazily by edgelist()

        def edgelist(self, counter=None):
            # Start a fresh counter on each top-level call
            # (in Python 3 the iterator's bound method is __next__).
            if counter is None:
                counter = IT.count().__next__
            if self.name is None:
                self.name = counter()
            # Name both children first, so siblings get consecutive IDs...
            for node in (self.left, self.right):
                if node is not None:
                    if node.name is None:
                        node.name = counter()
                    yield (self.name, node.name)
            # ...then recurse into each subtree, sharing the same counter.
            for node in (self.left, self.right):
                if node is not None:
                    yield from node.edgelist(counter)

    tree = [BinaryTree(i) for i in range(5)]
    tree[0].left = tree[1]
    tree[0].right = tree[2]
    tree[2].left = tree[3]
    tree[2].right = tree[4]

    edgelist = list(tree[0].edgelist())
    print(edgelist)

    G = nx.Graph(edgelist)
    nx.draw_spectral(G)
    plt.show()
    

    yields

    [(0, 1), (0, 2), (2, 3), (2, 4)]
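The question mentions Graphviz, and `nx.draw_spectral` does not lay the nodes out as a tree. One alternative is to write the edge list straight out in Graphviz's DOT language and let the `dot` layout engine draw it top-down. A minimal sketch in plain Python; `to_dot` and its optional `labels` mapping (node ID to displayed text, for example each node's `data`) are hypothetical, not part of either post.

```python
def to_dot(edges, labels=None, name="tree"):
    """Render (parent, child) edges as Graphviz DOT source for a directed graph."""
    lines = [f"digraph {name} {{"]
    # Optional per-node labels; only double quotes are escaped in this sketch.
    for node_id, text in (labels or {}).items():
        escaped = str(text).replace('"', '\\"')
        lines.append(f'    {node_id} [label="{escaped}"];')
    # One directed edge statement per (parent, child) pair.
    for parent, child in edges:
        lines.append(f"    {parent} -> {child};")
    lines.append("}")
    return "\n".join(lines)


edges = [(0, 1), (0, 2), (2, 3), (2, 4)]
dot_source = to_dot(edges, labels={0: "root"})
print(dot_source)
```

Saving `dot_source` to a file such as `tree.dot` and running `dot -Tpng tree.dot -o tree.png` should render the tree with the root at the top.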
    
