Tags: java, tree, binary-search-tree, avl-tree, depth

Visualizing self-balancing binary tree (AVL-tree)


I have written an application for visualizing a binary search tree. Now I am trying to modify it to visualize a self-balancing binary search tree. I have the balancing methods, but there seems to be a problem with updating the depth and index variables in each node, which are used to calculate each node's draw position. The code has become quite difficult to wrap my head around after trying a bunch of different things, but I suspect there is a simple solution, so I thought I'd ask here.

Example run: insert the nodes 50, 60 and 70. The tree should look like this:

60 (depth = 0, index = 1, height = 1, size = 2),
50 (depth = 1, index = 1, height = 0, size = 0) and
70 (depth = 1, index = 2, height = 0, size = 0)
    60
  50  70 

But instead, it looks like this:

50(depth = 0, index = 1, height = 0, size = 0)
60(depth = 1, index = 2, height = 2, size = 2)
70(depth = 2, index = 4, height = 0, size = 0)
50
  60
    70

This is how the draw position is computed, along with the public addNode method:

public void setDrawPosition(Node node) {
    node.drawX = (node.index * DrawPanel.width) / ((int) Math.pow(2, node.depth) + 1);
    node.drawY = node.depth * DrawPanel.height / (depth + 1);
}

public void addNode(int val) {
    root = addNode(val, root);
    System.out.println("tree depth is now: " + this.depth);
    root.height = this.depth;
}
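To see how the two sets of values above turn into positions, here is a small sketch that plugs the (depth, index) pairs from the example run into the same formulas as setDrawPosition. The 800 x 600 panel size is an assumption standing in for DrawPanel.width and DrawPanel.height, and the tree depth is passed in explicitly rather than read from the depth field.

```java
// Sketch: the setDrawPosition formulas applied to the example run.
// WIDTH and HEIGHT are assumed stand-ins for DrawPanel.width and DrawPanel.height.
class DrawPositionDemo {
    static final int WIDTH = 800;
    static final int HEIGHT = 600;

    // Same expression as node.drawX in setDrawPosition.
    static int drawX(int index, int depth) {
        return (index * WIDTH) / ((int) Math.pow(2, depth) + 1);
    }

    // Same expression as node.drawY; treeDepth plays the role of the tree's depth field.
    static int drawY(int depth, int treeDepth) {
        return depth * HEIGHT / (treeDepth + 1);
    }

    static void print(int key, int depth, int index, int treeDepth) {
        System.out.println(key + " -> (" + drawX(index, depth) + ", " + drawY(depth, treeDepth) + ")");
    }

    public static void main(String[] args) {
        System.out.println("Expected values (tree depth 1):");
        print(60, 0, 1, 1);
        print(50, 1, 1, 1);
        print(70, 1, 2, 1);

        System.out.println("Actual values (tree depth 2):");
        print(50, 0, 1, 2);
        print(60, 1, 2, 2);
        print(70, 2, 4, 2);
    }
}
```

With these numbers the expected tree lands at (400, 0), (266, 300) and (533, 300), while the actual values give (400, 0), (533, 200) and (640, 400): each node one step further right and down, which is the diagonal shown above.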

Here's the private, recursive addNode method:

private Node addNode(int val, Node node) {
    if (node == null) {
        return new Node(val, 0, 0, 1, 0); // Node(int val, int depth, int size, int index, int height)
    }
    if (val < node.key) {
        node.left = addNode(val, node.left);
        node.left.depth = node.depth + 1;
        if (node.left.depth > this.depth) {
            depth = node.left.depth;
            root.height = node.left.depth;
        }
        node.left.parent = node;
        node.left.index = (node.index * 2) - 1;
        if (height(node.left) - height(node.right) == 2) {
            if (val < node.left.key)
                node = rotateWithLeftChild(node);
            else
                node = doubleWithLeftChild(node);
        }
    } else {
        node.right = addNode(val, node.right);
        node.right.depth = node.depth + 1;
        if (node.right.depth > this.depth) {
            depth = node.right.depth;
            root.height = node.right.depth;
        }
        node.right.parent = node;
        node.right.index = (node.index * 2);
        if (height(node.right) - height(node.left) == 2) {
            if (val > node.right.key)
                node = rotateWithRightChild(node);
            else
                node = doubleWithRightChild(node);
        }
    }
    node.height = max(height(node.left), height(node.right)) + 1;
    node.size++;
    return node;
}

One of the simple rotations (they are similar):

private Node rotateWithRightChild(Node k1) {
    Node k2 = k1.right;
    k1.right = k2.left;
    k2.left = k1;
    k1.height = max(height(k1.left), height(k1.right)) + 1;
    k2.height = max(height(k2.right), k1.height) + 1;

    return k2;
}
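Tracing the example run through the code above, the "actual" values are the ones 50, 60 and 70 had just before the final rotation: the rotation rewires left and right but never touches depth, index or parent, and the new root's depth is never reset to 0. If you want to keep storing those fields, one option is to recompute them top-down from the root after each insertion instead of patching them during the descent. The following is a minimal sketch under that assumption; the stripped-down Node and the LayoutRefresher.refresh method are hypothetical, not part of the original code.

```java
// Hypothetical stripped-down Node: only the fields the refresh pass touches.
// The question's Node also has size, height, drawX and drawY.
class Node {
    int key;
    int depth;
    int index;
    Node left, right, parent;

    Node(int key) {
        this.key = key;
    }
}

class LayoutRefresher {
    // Re-derives parent, depth and index for `node` and everything below it,
    // using the same numbering as addNode: left child 2*i - 1, right child 2*i.
    // Returns the depth of the deepest node, i.e. the value for the tree's
    // depth field (-1 for an empty tree).
    static int refresh(Node node, Node parent, int depth, int index) {
        if (node == null) {
            return depth - 1; // depth of the parent, or -1 for an empty tree
        }
        node.parent = parent;
        node.depth = depth;
        node.index = index;
        int leftDeepest = refresh(node.left, node, depth + 1, 2 * index - 1);
        int rightDeepest = refresh(node.right, node, depth + 1, 2 * index);
        return Math.max(leftDeepest, rightDeepest);
    }

    public static void main(String[] args) {
        // Shape after the rotation in the example run, still carrying stale values.
        Node n50 = new Node(50), n60 = new Node(60), n70 = new Node(70);
        n60.left = n50;
        n60.right = n70;
        n60.depth = 1;
        n60.index = 2;
        n70.depth = 2;
        n70.index = 4;

        int treeDepth = refresh(n60, null, 0, 1);
        System.out.println("tree depth = " + treeDepth);
        for (Node n : new Node[] {n60, n50, n70}) {
            System.out.println(n.key + " (depth = " + n.depth + ", index = " + n.index + ")");
        }
    }
}
```

In the question's code, a call like `depth = LayoutRefresher.refresh(root, null, 0, 1);` right after `root = addNode(val, root);` would make the depth and index assignments inside the private addNode redundant. It costs one pass over the tree per insertion, which should be negligible for a visualizer.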

Solution

  • I would suggest not storing this extra information in the nodes, as it can be a pain to keep it up to date. Instead, determine it dynamically whenever you need to draw the tree.

    For instance, you could keep only val, left and right as properties of a Node instance, and define a recursive method that calculates the height of the current node. The drawing method can then use that to get the overall height of the tree, and use a breadth-first traversal to derive all the other information needed to draw it.

    Here is some code that does a simplified "draw": it just prints the tree line by line, with appropriate indentation. It should be simple to adapt to your drawing mechanism:

    import java.util.ArrayList;
    
    class Node {
        int val;
        Node left;
        Node right;
    
        Node(int val) {
            this.val = val;
            left = right = null;
        }
    
        // Plain (unbalanced) BST insertion; returns the newly created node.
        Node add(int val) {
            return val < this.val
               ? left != null ? left.add(val)
                              : (left = new Node(val))
               : right != null ? right.add(val)
                               : (right = new Node(val));
        }
    
        // Number of levels in the subtree rooted at this node (a leaf has height 1).
        int getHeight() {
            return 1 + Math.max(
                left == null ? 0 : left.getHeight(),
                right == null ? 0 : right.getHeight()
            );
        }
    
        // Prints the tree breadth-first, one line per level. Each level is treated
        // as a full row of slots (null for missing nodes) so the columns line up.
        void draw() {
            int colWidth = 5;
            int height = getHeight();
            int colDistance = (int) Math.pow(2, height);
            ArrayList<Node> level = new ArrayList<Node>();
            level.add(this);
            // colDistance runs 2^height, ..., 2: exactly one iteration per level
            while (colDistance > 1) {
                ArrayList<Node> nextLevel = new ArrayList<Node>();
                String line = "";
                int col = colDistance / 2 - 1; // column of the first slot on this level
                for (int i = 0; i < level.size(); i++) {
                    Node node = level.get(i);
                    if (node == null) {
                        nextLevel.add(null);
                        nextLevel.add(null);
                    } else {
                        if (col > 0) { // pad the line up to this node's column
                            line = String.format("%-" + (col * colWidth) + "s", line);
                        }
                        line += Integer.toString(node.val);
                        nextLevel.add(node.left);
                        nextLevel.add(node.right);
                    }
                    col += colDistance;
                }
                System.out.println(line);
                level = nextLevel;
                colDistance /= 2;
            }
        }
    }
    

    Demo use:

        class Demo {
            public static void main(String[] args) {
                Node root = new Node(40);
                root.add(50);
                root.add(30);
                root.add(20);
                root.add(60);
                root.add(35);
                root.draw();
            }
        }
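To connect the answer back to the question's setDrawPosition formula: the same breadth-first walk can derive each node's depth and index on the fly (root index 1, left child 2i - 1, right child 2i) and convert them to pixel coordinates without storing anything in the nodes. This is a sketch under assumptions: Placement and TreeLayout are hypothetical names, and width and panelHeight stand in for DrawPanel.width and DrawPanel.height. Because it only reads left and right, it gives the same result whether or not the tree was just rebalanced.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;

class Node {
    int val;
    Node left, right;

    Node(int val) {
        this.val = val;
    }
}

// Hypothetical holder for one node's computed screen position.
class Placement {
    final Node node;
    final int x, y;

    Placement(Node node, int x, int y) {
        this.node = node;
        this.x = x;
        this.y = y;
    }
}

class TreeLayout {
    // Queue entry: a node plus the depth and index derived for it during the walk.
    private static final class Entry {
        final Node node;
        final int depth, index;

        Entry(Node node, int depth, int index) {
            this.node = node;
            this.depth = depth;
            this.index = index;
        }
    }

    // Number of levels (a single node has height 1), as in the answer's getHeight.
    static int height(Node n) {
        return n == null ? 0 : 1 + Math.max(height(n.left), height(n.right));
    }

    // Breadth-first walk that applies the same formulas as setDrawPosition.
    static List<Placement> layout(Node root, int width, int panelHeight) {
        List<Placement> result = new ArrayList<>();
        if (root == null) {
            return result;
        }
        int treeDepth = height(root) - 1; // depth of the deepest node
        ArrayDeque<Entry> queue = new ArrayDeque<>();
        queue.add(new Entry(root, 0, 1));
        while (!queue.isEmpty()) {
            Entry e = queue.poll();
            int x = (e.index * width) / ((1 << e.depth) + 1); // 1 << depth == 2^depth
            int y = e.depth * panelHeight / (treeDepth + 1);
            result.add(new Placement(e.node, x, y));
            if (e.node.left != null) {
                queue.add(new Entry(e.node.left, e.depth + 1, 2 * e.index - 1));
            }
            if (e.node.right != null) {
                queue.add(new Entry(e.node.right, e.depth + 1, 2 * e.index));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // The balanced tree from the question's example run.
        Node root = new Node(60);
        root.left = new Node(50);
        root.right = new Node(70);
        for (Placement p : layout(root, 800, 600)) {
            System.out.println(p.node.val + " at (" + p.x + ", " + p.y + ")");
        }
    }
}
```

For the 60/50/70 tree on an 800 x 600 panel this prints 60 at (400, 0), 50 at (266, 300) and 70 at (533, 300). The shift-based power of two assumes the tree is shallower than 31 levels, which any drawable tree will be.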