Tags: python, data-structures, tree, decision-tree, treelib

Iterating through a tree data structure using Treelib (Python)


I created some nodes with a Node class and added them to a tree using Treelib.

class Node(object):
    def __init__(self, id, label, parent, type):
        self.id = id
        self.label = label
        self.parent = parent
        self.type = type

node = Node(id, label, parent, type)

if id == 'root':
    tree.create_node(tag=id, identifier=id, data=node)
else:
    tree.create_node(tag=id, identifier=id, parent=parent, data=node)

Calling tree.show() gives me a great overview of the tree. Now I want to iterate through the tree and get the data of each node defined earlier (not only a single property, as with tree.show(data_property="")).

Do you have any ideas on how to work with the defined data?

My final goal is to calculate the tree structure like a decision tree, and I haven't found a good way to do that with Treelib so far.


Solution

  • First, some remarks:

    • I would use a different name for the class Node, as TreeLib also defines a Node class.
    • As you use TreeLib, there is no need to maintain parent references in your own class instances. TreeLib already manages this for you.

    You can iterate over the nodes with the all_nodes_itr method, which gives you a TreeLib Node instance in each iteration. You can then read TreeLib attributes such as identifier or tag, and look up the parent with tree.parent(node.identifier). For your own attributes, go through the data attribute first and then the attribute you want to see (like label), for example node.data.label.

    Here is a simplified script:

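    from treelib import Tree


    class MyNode(object):
        """Payload stored in the data attribute of each TreeLib node."""
        def __init__(self, id, label):
            self.id = id
            self.label = label


    tree = Tree()


    def add_node(id, label, parent=None):
        # Wrap the custom attributes in MyNode; TreeLib keeps track of the parent
        node = MyNode(id, label)
        tree.create_node(tag=id, identifier=id, data=node, parent=parent)


    # The first node is added without a parent, so it becomes the root
    add_node("world", "World")
    add_node("north-america", "North America", "world")
    add_node("europe", "Europe", "world")

    # Each iteration yields a TreeLib Node; the MyNode instance is in node.data
    for node in tree.all_nodes_itr():
        print(node.identifier, node.data.label)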