Search code examples
pythonbinary-search-treepretty-print

pretty printing a binary search tree


I am trying to print a binary search tree in "graphical" way (The prettyPrint function does that). My node class looks like this:

class Node:
    def __init__(self, dataValue):
        self.dataValue = dataValue
        self.leftChild = None
        self.rightChild = None

My tree class looks something like this:

class binary_search_tree(Node):
    def __init__(self):
        self.root = None

    def insert(self, value):
        if self.root is None:
            self.root = Node(value)
        else:
            self._insert(value, self.root)

    def _insert(self, value, cur_node):
        if value < cur_node.dataValue:
            if cur_node.leftChild is None:
                cur_node.leftChild = Node(value)
            else:
                self._insert(value, cur_node.leftChild)
        if value > cur_node.dataValue:
            if cur_node.rightChild is None:
                cur_node.rightChild = Node(value)
            else:
                self._insert(value, cur_node.rightChild)

    def print_tree(self, transversal_type):
        if transversal_type == "preorder":
            return self._print_preorder(self.root, "")
        elif transversal_type == "inorder":
            return self._print_inorder(self.root, "")
        elif transversal_type == "postorder":
            return self._print_postorder(self.root, "")
        else:
            print("{} does not exist".format(transversal_type))
            return False

    def _print_tree(self, root, indent, transversal = ""):

        if root is not None:
            self._print_tree(root.rightChild, indent + "   ")
            transversal += indent + str(root.dataValue)
            self._print_tree(root.leftChild, indent + "   ")
        return transversal


    def _print_preorder(self, start, transversal):
        # Root -> Left -> Right
        if start:
            transversal += (str(start.dataValue) + " - ")
            transversal = self._print_preorder(start.leftChild, transversal)
            transversal = self._print_preorder(start.rightChild, transversal)
        return transversal

    def _print_inorder(self, start, transversal):
        #Left -> Root -> Right
        if start:
            transversal = self._print_inorder(start.leftChild, transversal)
            transversal += (str(start.dataValue) + " - ")
            transversal = self._print_inorder(start.rightChild, transversal)
        return transversal

    def _print_postorder(self, start, transversal):
        #Left -> Right -> Root
        if start:
            transversal = self._print_postorder(start.leftChild, transversal)
            transversal = self._print_postorder(start.rightChild, transversal)
            transversal += (str(start.dataValue) + " - ")
        return transversal

    def search(self, value):
        if self.root!=None:
            return self._search(value,self.root)
        else:
            return False

    def _search(self, value, cur_node):
        if value==cur_node.dataValue:
            return True
        elif value < cur_node.dataValue and cur_node.leftChild!=None:
            return self._search(value, cur_node.leftChild)
        elif value > cur_node.dataValue and cur_node.rightChild!=None:
            return self._search(value, cur_node.rightChild)
        return False

    def min_value(self, node):
        current = node
        while(current.leftChild is not None):
            current = current.leftChild
        return current

    def delete_node(self, node, value):
        if node is None:
            return node
        if value < node.dataValue:
            node.leftChild = self.delete_node(node.leftChild, value)
        elif value > node.dataValue:
            node.rightChild = self.delete_node(node.rightChild, value)
        else:
            if node.leftChild is None:
                temp = node.rightChild
                node = None
                return temp
            elif node.rightChild is None:
                temp = node.leftChild
                node = None
                return temp
            temp = self.min_value(node)
            node.dataValue = temp.dataValue
            node.rightChild = self

    def getNumNodes(self):
        if self.root:
            return self._getNumNodes(self.root)
        else:
            return 0

    def _getNumNodes(self, node):
        total = 1
        if node.leftChild:
            total += self._getNumNodes(node.leftChild)
        if node.rightChild:
            total += self._getNumNodes(node.rightChild)
        return total

    def getHeight(self):
        return self._getHeight(self.root)

    def _getHeight(self, node):
        if not node:
            return 0
        else:
            return max(self._getHeight(node.leftChild), self._getHeight(node.rightChild)) + 1

    def fillTree(self, height):
        self._fillTree(self.root, height)

    def _fillTree(self, node, height):
        if height <= 1:
            return
        if node:
            if not node.leftChild: node.leftChild = Node(' ')
            if not node.rightChild: node.rightChild = Node(' ')
            self._fillTree(node.leftChild, height - 1)
            self._fillTree(node.rightChild, height - 1)

    def prettyPrint(self):
        """
        """
        # get height of tree
        total_layers = self.getHeight()

        tree = deepcopy(self)

        tree.fillTree(total_layers)
        # start a queue for BFS
        queue = Queue()
        # add root to queue
        queue.enqueue(tree)  # self = root
        # index for 'generation' or 'layer' of tree
        gen = 1
        # BFS main
        while not queue.isEmpty():
            # copy queue
            #
            copy = Queue()
            while not queue.isEmpty():
                copy.enqueue(queue.dequeue())
            #
            # end copy queue

            first_item_in_layer = True
            edges_string = ""
            extra_spaces_next_node = False

            # modified BFS, layer by layer (gen by gen)
            while not copy.isEmpty():

                root = copy.dequeue()

                # -----------------------------
                # init spacing
                spaces_front = pow(2, total_layers - gen + 1) - 2
                spaces_mid = pow(2, total_layers - gen + 2) - 2
                dash_count = pow(2, total_layers - gen) - 2
                if dash_count < 0:
                    dash_count = 0
                spaces_mid = spaces_mid - (dash_count * 2)
                spaces_front = spaces_front - dash_count
                init_padding = 2
                spaces_front += init_padding
                if first_item_in_layer:
                    edges_string += " " * init_padding
                # ----------------------------->

                # -----------------------------
                # construct edges layer
                edge_sym = "/" if root.leftChild and root.leftChild.data is not " " else " "
                if first_item_in_layer:
                    edges_string += " " * (pow(2, total_layers - gen) - 1) + edge_sym
                else:
                    edges_string += " " * (pow(2, total_layers - gen + 1) + 1) + edge_sym
                edge_sym = "\\" if self.root.rightChild and self.root.rightChild.data is not " " else " "
                edges_string += " " * (pow(2, total_layers - gen + 1) - 3) + edge_sym
                # ----------------------------->

                # -----------------------------
                # conditions for dashes
                if self.root.leftChild and self.root.leftChild.data == " ":
                    dash_left = " "
                else:
                    dash_left = "_"

                if self.root.rightChild and self.root.rightChild.data == " ":
                    dash_right = " "
                else:
                    dash_right = "_"
                # ----------------------------->

                # -----------------------------
                # handle condition for extra spaces when node lengths don't match or are even:
                if extra_spaces_next_node:
                    extra_spaces = 1
                    extra_spaces_next_node = False
                else:
                    extra_spaces = 0
                # ----------------------------->

                # -----------------------------
                # account for longer data
                data_length = len(str(self.root.data))
                if data_length > 1:
                    if data_length % 2 == 1:  # odd
                        if dash_count > 0:
                            dash_count -= ((data_length - 1) / 2)
                        else:
                            spaces_mid -= (data_length - 1) / 2
                            spaces_front -= (data_length - 1) / 2
                            if data_length is not 1:
                                extra_spaces_next_node = True
                    else:  # even
                        if dash_count > 0:
                            dash_count -= ((data_length) / 2) - 1
                            extra_spaces_next_node = True
                            # dash_count += 1
                        else:
                            spaces_mid -= (data_length - 1)
                            spaces_front -= (data_length - 1)
                # ----------------------------->

                # -----------------------------
                # print node with/without dashes
                if first_item_in_layer:
                    print(str(" " * spaces_front)   + str(dash_left * dash_count) + str(node.data) + str(dash_right * dash_count), end="")
                    first_item_in_layer = False
                else:
                    print((" " * (spaces_mid - extra_spaces)) + (dash_left * dash_count) + (self.root.data) + (dash_right * dash_count), end=""),
                # ----------------------------->

                if self.root.leftChild: queue.enqueue(self.root.leftChild)
                if self.root.rightChild: queue.enqueue(self.root.rightChild)

            # print the fun squiggly lines
            if not queue.isEmpty():
                print("\n" + edges_string)
            gen += 1

The error that I'm getting is 'binary_search_tree' object has no attribute 'leftChild' at this part of the code:

edge_sym = "/" if root.leftChild and root.leftChild.data is not " " else " "

I'm guessing that I have "conflicted" the classes, but I have no idea how to fix that.


Solution

  • Your code is incomplete and not really a minimal reproducible example. But it looks like

    Here you copy the binary_search_tree object

    tree = deepcopy(self)
    

    And then assume that tree is your root object aka a Node

    queue.enqueue(tree)  # self = root
    

    Which is not true because your root node would be tree.root.

    This solves at least your problem at the moment but there are many more issues like trying to print .data from a node when it's named .dataValue.