Search code examples
pythoniteratortree-traversal

non binary tree traversal iterator


I'm trying to figure out how to use iter and next for a self made tree which has nodes classes as nodes(and root) and each has a list of children. I want to traverse depth first-in order, but only until half the children, then the parent and then the right children.

I created a stack with push\pop and size functionality ,tried iterating by pushing elements into the stack and then when the stack is empty it means we are done, but I can't get it to work.

class node:
    def __init__(self,val=1):
        self.val=val
        self.children=[]
        self.parent=None
    def add_child(self,child):
        self.children.append(child)
        child.parent=self
    def __repr__(self):
        resultstr="Node"+str(self.val)
        return resultstr

class tree:
    def __init__(self,root): # root is of type node
        self.root=root
    def __iter__(self):
        self.currnode=self.root
        self.visitednodes=[]
        self.returnednodes=[]
        self.stack=Stack()
        self.visitednodes.append(self.root)
        self.stack.push(self.root)
        return self

        def __next__(self): # we want to traverse tree in-order until half of children,then us,then the other half
        if self.stack.size==0:
            raise StopIteration
        else:
            u=self.stack.pop()
            if len(u.children)==0:

                self.returnednodes.append(u)
                return u
            else:
                for i in range(len(u.children)):
                    if u.children[i] not in self.visitednodes:
                        self.visitednodes.append(u.children[i])
                        self.stack.push(u.children[i])
                        #break
                        
                for i in range(len(u.children)//2):
                    if u.children[i] not in self.returnednodes:
                        self.returnednodes.append(u.children[i])
                        return u.children[i]
                if u not in self.returnednodes:
                    self.returnednodes.append(u)
                    return u
                for i in range((len(u.children)//2)+1):
                    if u.children[i] not in self.returnednodes:
                        self.returnednodes.append(u.children[i])
                        return u.children[i]
class Stack:
    def __init__(self):
        self.data=[]
        self.size=0
    def push(self,element):
        self.size+=1
        self.data.append(element)
    def pop(self):
        res=self.data.pop(self.size-1)
        self.size-=1
        return res

Here is the code I'm trying to run:

if __name__ == '__main__':
    node6 =node(6)
    node3=node(3)
    node5=node(5)
    node9=node(9)
    node1=node(1)
    node2=node(2)
    node4=node(3)
    node8=node(8)
    node10=node(10)
    node7=node(7)
    node6.add_child(node3)
    node6.add_child(node5)
    node6.add_child(node9)
    node3.add_child(node1)
    node3.add_child(node2)
    node3.add_child(node4)
    node9.add_child(node8)
    node9.add_child(node10)
    node8.add_child(node7)
    mytree=tree(node6)
    for node in mytree:
        print(node)

The output I get: Node3 Node8 Node10 Node7 Node7 Node5 Node1 Node4 Node2 Node1 Which seems odd, i would expect to at least start with node 1.

The wanted output: enter image description here Node1Node2Node3Node4Node5Node6Node7Node8Node9Node10


Solution

  • class CompanyTree:

    def __init__(self, root=None):
        self.root = root
    
    def set_root(self, root):
        self.root=root
    
    def __iter__(self):
        self.visitednodes=[]
        self.stack = Stack()
        self.stack.push(self.root)
        return self
    
    def __next__(self):
        if self.stack.size == 0:
            raise StopIteration
        else:
            u = self.stack.pop()
            while u not in self.visitednodes:
                self.visitednodes.append(u)
                if len(u.children) == 0:
                    return u
                else:
                    for i in range((len(u.children) + 1) // 2, len(u.children)):
                        self.stack.push(u.children[i])
    
                    self.stack.push(u)
    
                    for i in range(((len(u.children) + 1) // 2) - 1, -1, -1):
                        self.stack.push(u.children[i])
    
                    u = self.stack.pop()
            return u