Tags: java, data-structures, avl-tree

Having an issue with my AVL tree code (Implementation)


I have recently been learning about trees and came across AVL trees. I followed Kunal Kushwaha's YouTube videos to learn these concepts. I have implemented an AVL tree, and I am pasting my full AVL Java class below.

The problem is that when I try to add 1000 elements (values) into the tree via the Main method, the height is supposed to be 3, because the height is log(n) -> log(1000) -> 3, but I am getting 12 as output instead of 3. In Kunal's video the code he implemented produced 3, whereas my nearly identical code produces a different output. I have spent 2 hours trying to figure out where the mistake is but couldn't find it. I even asked ChatGPT, but it was useless. Your help in finding the mistake would be much appreciated.

I am posting both versions of the code below.

  1. My code:
public class AVL_Trees {
    // I don't know why, but when I insert 1000 elements the expected height is 3 and I get 12.
    // I tried ChatGPT and cross-checked it myself, but still can't find where the error is.
    //post on reddit or stackoverflow
    private class Node{
        private int value;
        private Node left;
        private Node right;
        private int height;
        public Node(int value){
            this.value = value;
        }

        public int getValue() {
            return value;
        }

    }

    private Node root;
    public AVL_Trees(){} //constructor
    public int height() {
        return height(root);
    }
    public int height(Node node) { //helper function
        if (node == null) {
            return -1;
        }
        return node.height;
    }

    public boolean isEmpty(Node node){
        return node == null;
    }

    public void prettyDisplay(){
        prettyDisplay(root,0);
        System.out.println("Height "+height(root));
    }
    private void prettyDisplay(Node node,int level){
        if(node == null){
            return;
        }
        prettyDisplay(node.right,level+1);
        if(level!=0){
            for(int i=0;i<level;i++){
                System.out.print("|\t\t");
            }
            System.out.println("|--->" + node.value);
        }
        else {
            System.out.println("|--->" + node.value);
        }

        prettyDisplay(node.left,level+1);
    }

    public void insert(int value){
        root = insert(root,value);
    }
    private Node insert(Node node, int value){
        if(node==null){
            return new Node(value);
        }
        if(value < node.value){
            node.left = insert(node.left,value);
        }
        if(value > node.value){
            node.right = insert(node.right,value);
        }
        node.height = Math.max(height(node.left),height(node.right)) + 1;
        return rotate(node);
    }

    //for arrays as input
    public void populate(int[] nums){
        for(int i : nums){
            insert(i);
        }
    }


    public void populateSorted(int[] nums){ //incase if input array is sorted
        populateSorted(nums,0, nums.length);
    }
    private void populateSorted(int[] nums,int start,int end){
        if(start>=end){
            return;
        }
        int mid = start + (end - start)/2;
        insert(nums[mid]);
        populateSorted(nums,start,mid); // recurse on the left half of the current range, not from index 0
        populateSorted(nums,mid+1,end);
    }

    public boolean balanced(){
        return balanced(this.root);
    }
    private boolean balanced(Node node){
        if(node == null){
            return true;
        }
        // condition for balance is both child node's height difference 
        //should be less than or equal to 1
        return Math.abs(height(node.left) - height(node.right)) <=1 && balanced(node.left) && balanced(node.right);
    }

    public void display(){
        display(root,"The root Node is ");
    }
    private void display(Node node , String details){ //just a way of representing
        if(node == null){
            return;
        }
        System.out.println(details + node.value);
        display(node.left,"The left node of "+node.value+" is: ");
        display(node.right,"The right node of "+node.value+" is: ");
    }

    private Node rotate(Node node){
        if(height(node.left) - height(node.right) > 1){
            //left heavy
            if(height(node.left.left) - height(node.left.right) > 0 ){ 
              //put 0, not 1 or -1 think about it!
                //left-left case
                return rightRotate(node);
            }
            if(height(node.left.left) - height(node.left.right) < 0){
                //left-right case
                node.left =  leftRotate(node.left);
                return rightRotate(node);
            }
        }

        if(height(node.left) - height(node.right) < -1){
            //right heavy
            if(height(node.right.left) - height(node.right.right) < 0){
                //right-right case
                return leftRotate(node);
            }
            if(height(node.right.left) - height(node.right.right) > 0){
                //right-left case
                node.right = rightRotate(node.right);
                return leftRotate(node);
            }
        }
        return node;
    }

    private Node rightRotate(Node p){
        Node c = p.left;
        Node t = c.right;
        c.right = p;
        p.left = t;
        //updating heights
        p.height = Math.max(height(p.left),height(p.right))+1;
        c.height = Math.max(height(c.left),height(c.right))+1;
        return c;
    }

    private Node leftRotate(Node c){
        Node p = c.right;
        Node t = p.left;
        p.left = c;
        c.right = t;
        //updating heights
        p.height = Math.max(height(p.left),height(p.right))+1;
        c.height = Math.max(height(c.left),height(c.right))+1;
        return p;
    }
}

  2. Kunal's correct code:

//this is the correct code for AVL tree where it can balance and height is 3 when adding 1000 elements
public class KunalAVL {
        public class Node {
            private int value;
            private Node left;
            private Node right;
            private int height;
            public Node(int value) {
                this.value = value;
            }

            public int getValue() {
                return value;
            }
        }

        private Node root;
        public KunalAVL() {

        }

        public int height() {
            return height(root);
        }
        private int height(Node node) {
            if (node == null) {
                return -1;
            }
            return node.height;
        }

        public void insert(int value) {
            root = insert(value, root);
        }

        private Node insert(int value, Node node) {
            if (node == null) {
                node = new Node(value);
                return node;
            }

            if (value < node.value) {
                node.left = insert(value, node.left);
            }

            if (value > node.value) {
                node.right = insert(value, node.right);
            }

            node.height = Math.max(height(node.left), height(node.right)) + 1;
            return rotate(node);
        }

        private Node rotate(Node node) {
            if (height(node.left) - height(node.right) > 1) {
                // left heavy
                if(height(node.left.left) - height(node.left.right) > 0) {
                    // left left case
                    return rightRotate(node);
                }
                if(height(node.left.left) - height(node.left.right) < 0) {
                    // left right case
                    node.left = leftRotate(node.left);
                    return rightRotate(node);
                }
            }

            if (height(node.left) - height(node.right) < -1) {
                // right heavy
                if(height(node.right.left) - height(node.right.right) < 0) {
                    // right right case
                    return leftRotate(node);
                }
                if(height(node.right.left) - height(node.right.right) > 0) {
                    // right left case
                    node.right = rightRotate(node.right);
                    return leftRotate(node);
                }
            }

            return node;
        }

        public Node rightRotate(Node p) {
            Node c = p.left;
            Node t = c.right;
            c.right = p;
            p.left = t;
            p.height = Math.max(height(p.left), height(p.right) + 1);
            c.height = Math.max(height(c.left), height(c.right) + 1);
            return c;
        }

        public Node leftRotate(Node c) {
            Node p = c.right;
            Node t = p.left;
            p.left = c;
            c.right = t;
            p.height = Math.max(height(p.left), height(p.right) + 1);
            c.height = Math.max(height(c.left), height(c.right) + 1);
            return p;
        }

        public void populate(int[] nums) {
            for (int i = 0; i < nums.length; i++) {
                this.insert(nums[i]);
            }
        }

        public void populatedSorted(int[] nums) {
            populatedSorted(nums, 0, nums.length);
        }

        private void populatedSorted(int[] nums, int start, int end) {
            if (start >= end) {
                return;
            }

            int mid = (start + end) / 2;
            this.insert(nums[mid]);
            populatedSorted(nums, start, mid);
            populatedSorted(nums, mid + 1, end);
        }

        public void display() {
            display(this.root, "Root Node: ");
        }

        private void display(Node node, String details) {
            if (node == null) {
                return;
            }
            System.out.println(details + node.value);
            display(node.left, "Left child of " + node.value + " : ");
            display(node.right, "Right child of " + node.value + " : ");
        }

        public boolean isEmpty() {
            return root == null;
        }

        public boolean balanced() {
            return balanced(root);
        }

        private boolean balanced(Node node) {
            if (node == null) {
                return true;
            }
            return Math.abs(height(node.left) - height(node.right)) <= 1 && balanced(node.left) && balanced(node.right);
        }


}

Solution

  • You wrote: "when I try to add 1000 elements(values) into the tree via the Main method, the height is supposed to be 3"

    Note that a binary tree (which an AVL tree is) of height 3 can hold at most 15 elements.

    For height 1, you can store at most 3 nodes. For height 2, that is 7 nodes, and for height 3, the maximum is 15:

    height=1      height=2            height=3 
    
      O               O                ____O____
     / \            /   \             /         \
    O   O          O     O           O           O
                  / \   / \        /   \       /   \
                 O   O O   O      O     O     O     O
                                 / \   / \   / \   / \
                                O   O O   O O   O O   O
    

    Evidently there is no way you could store 1000 elements in a binary tree of height 3.
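
    As a quick sanity check, the capacity argument can be sketched in Java. `TreeCapacity`, `maxNodes` and `minHeight` are hypothetical names made up for this illustration; height is counted in edges, matching the code in the question, where an empty tree has height -1.

```java
// Hypothetical helper for illustration; not part of either class in the question.
final class TreeCapacity {

    /** Maximum number of nodes in a binary tree of height h (counted in edges). */
    static long maxNodes(int h) {
        return (1L << (h + 1)) - 1; // 2^(h+1) - 1, a completely filled tree
    }

    /** Smallest height any binary tree with n >= 1 nodes can have. */
    static int minHeight(long n) {
        int h = 0;
        while (maxNodes(h) < n) {
            h++;
        }
        return h;
    }

    public static void main(String[] args) {
        for (int h = 1; h <= 3; h++) {
            System.out.println("height " + h + ": at most " + maxNodes(h) + " nodes");
        }
        System.out.println("1000 nodes need height >= " + minHeight(1000));
    }
}
```

    For 1000 nodes this gives a minimum height of 9, because a tree of height 8 holds at most 511 nodes.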

    "...because the height is log(n) -> log (1000) -> 3 but I am getting 12 as output instead of 3"

    You seem to have misinterpreted what was intended with log n. It is not log₁₀ n, but log₂ n. More precisely, the maximum height an AVL tree of n nodes can have is bounded by about 1.44 log₂ n.

    If you have an AVL tree with 1000 nodes it could have a height up to 13.
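
    One way to see where 13 comes from, as a sketch: the sparsest AVL tree of height h consists of a root, a sparsest subtree of height h-1 and a sparsest subtree of height h-2, so its node count follows a Fibonacci-like recurrence. The class and method names below are hypothetical and only serve the illustration; height is again counted in edges.

```java
// Hypothetical helper for illustration; not part of either class in the question.
final class AvlHeightBound {

    /** Fewest nodes an AVL tree of height h can contain (the empty tree has height -1). */
    static long minNodes(int h) {
        if (h < 0) {
            return 0;
        }
        long prev = 0; // minNodes(-1)
        long cur = 1;  // minNodes(0): a single node
        for (int i = 1; i <= h; i++) {
            long next = cur + prev + 1; // root + sparsest subtrees of height i-1 and i-2
            prev = cur;
            cur = next;
        }
        return cur;
    }

    /** Largest height an AVL tree with n nodes can have. */
    static int maxHeight(long n) {
        int h = -1;
        while (minNodes(h + 1) <= n) {
            h++;
        }
        return h;
    }

    public static void main(String[] args) {
        System.out.println("min nodes for height 13: " + minNodes(13));
        System.out.println("min nodes for height 14: " + minNodes(14));
        System.out.println("max height with 1000 nodes: " + maxHeight(1000));
    }
}
```

    Height 13 needs at least 986 nodes and height 14 at least 1596, so 1000 nodes cannot exceed height 13. That is consistent with the approximate bound, since 1.44 log₂ 1000 is roughly 14.3.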

    "but in Kunal's video the code he implemented produced 3"

    That is because the code you presented as Kunal's code has bugs. Indeed it outputs 3, but as shown above that couldn't possibly be correct.

    The problems in the code

    The second version you provided has a bug in the rotation functions leftRotate and rightRotate. It miscalculates the new height, because it adds the 1 only to the second argument of max, like here:

    p.height = Math.max(height(p.left), height(p.right) + 1);
    //                                                 ^^^^^^ 
    

    ...while that 1 should be added to the result of the max call. In the first version you presented, this is correctly done (good for you!).
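
    To make the effect of the misplaced `+ 1` concrete, here is a small sketch that isolates the two formulas. `HeightFormulaDemo`, `buggy` and `fixed` are hypothetical names; the arguments stand for the cached heights of the left and right subtrees.

```java
// Hypothetical demo for illustration; it isolates the two height formulas.
final class HeightFormulaDemo {

    /** The formula from the second version: the 1 is added to the right height only. */
    static int buggy(int left, int right) {
        return Math.max(left, right + 1);
    }

    /** The formula from the first version: the 1 is added to the maximum. */
    static int fixed(int left, int right) {
        return Math.max(left, right) + 1;
    }

    public static void main(String[] args) {
        // Left subtree of height 2, right subtree of height 0: the node has height 3.
        System.out.println(buggy(2, 0)); // 2, one too low
        System.out.println(fixed(2, 0)); // 3
        // Mirror image: the right subtree is the taller one, so both agree.
        System.out.println(buggy(0, 2)); // 3
        System.out.println(fixed(0, 2)); // 3
    }
}
```

    The buggy formula is only wrong when the left subtree is strictly taller than the right one, and then it is off by exactly one. Those underestimates accumulate up the tree, which is how a tree of 1000 nodes can end up reporting a height of 3.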

    There is another bug in leftRotate, in both versions of the code: it recalculates the heights of the nodes p and c in the wrong order. The deeper of the two nodes must be updated first, and after the rotation that is c, since it has become the child of p. Otherwise p's new height is computed from c's stale height.

    So fix it to:

        private Node leftRotate(Node c){
            Node p = c.right;
            Node t = p.left;
            p.left = c;
            c.right = t;
            //updating heights: c first, because c is now the child of p
            c.height = Math.max(height(c.left),height(c.right))+1;
            p.height = Math.max(height(p.left),height(p.right))+1;
            return p;
        }
    

    NB: In rightRotate you did it in the correct order.
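
    The effect of the update order can be reproduced on the smallest possible case, a right-leaning chain 1 -> 2 -> 3. The sketch below is a stripped-down stand-in for the classes in the question; the `childFirst` flag and the `chain` helper are made up for the demonstration.

```java
// Hypothetical demo: the same left rotation with both height-update orders.
final class RotationOrderDemo {

    static final class Node {
        final int value;
        Node left, right;
        int height; // cached height in edges; a leaf has height 0
        Node(int value) { this.value = value; }
    }

    static int height(Node n) {
        return n == null ? -1 : n.height;
    }

    static void update(Node n) {
        n.height = Math.max(height(n.left), height(n.right)) + 1;
    }

    /** Builds the chain 1 -> 2 -> 3 (right children only) with correct cached heights. */
    static Node chain() {
        Node a = new Node(1), b = new Node(2), c = new Node(3);
        a.right = b;
        b.right = c;
        update(c);
        update(b);
        update(a);
        return a;
    }

    static Node leftRotate(Node c, boolean childFirst) {
        Node p = c.right;
        Node t = p.left;
        p.left = c;
        c.right = t;
        if (childFirst) {
            update(c); // c is now below p, so its height must be refreshed first
            update(p);
        } else {
            update(p); // reads c's stale height (2) and stores 3
            update(c);
        }
        return p;
    }

    public static void main(String[] args) {
        System.out.println("child first:  " + leftRotate(chain(), true).height);  // 1
        System.out.println("parent first: " + leftRotate(chain(), false).height); // 3
    }
}
```

    After the rotation the tree is a root with two leaves, so the correct cached height of the new root is 1. Updating p first stores 3 instead.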

    The results

    After having made the above corrections, I inserted the numbers 1 to 1000 in sorted order into the tree, and a height of 9 is reported, which is perfectly fine. It means the tree is as balanced as it could be. Inserting those numbers in a random order gave me heights of mostly 11... which is as expected.
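
    For anyone who wants to repeat this kind of check, here is a minimal sketch of a test harness. It is a condensed re-implementation rather than the asker's class: `AvlHeightCheck`, `rebalance`, `realHeight`, `build` and `oneTo` are hypothetical names, `rebalance` folds the four rotation cases into two branches (equivalent for insert-only use), and the shuffle seed is arbitrary. Comparing the cached height with one recomputed from the structure catches exactly the kind of bookkeeping bug described above.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical harness: a condensed AVL insert with the corrected height updates.
final class AvlHeightCheck {

    static final class Node {
        final int value;
        Node left, right;
        int height; // cached height in edges; a leaf has height 0
        Node(int value) { this.value = value; }
    }

    static int height(Node n) { return n == null ? -1 : n.height; }

    static void update(Node n) {
        n.height = Math.max(height(n.left), height(n.right)) + 1;
    }

    static Node rightRotate(Node p) {
        Node c = p.left;
        p.left = c.right;
        c.right = p;
        update(p); // p is now the child, so it goes first
        update(c);
        return c;
    }

    static Node leftRotate(Node c) {
        Node p = c.right;
        c.right = p.left;
        p.left = c;
        update(c); // c is now the child, so it goes first
        update(p);
        return p;
    }

    static Node rebalance(Node node) {
        int balance = height(node.left) - height(node.right);
        if (balance > 1) { // left heavy
            if (height(node.left.left) < height(node.left.right)) {
                node.left = leftRotate(node.left); // left-right case
            }
            return rightRotate(node);
        }
        if (balance < -1) { // right heavy
            if (height(node.right.left) > height(node.right.right)) {
                node.right = rightRotate(node.right); // right-left case
            }
            return leftRotate(node);
        }
        return node;
    }

    static Node insert(Node node, int value) {
        if (node == null) return new Node(value);
        if (value < node.value) node.left = insert(node.left, value);
        else if (value > node.value) node.right = insert(node.right, value);
        update(node);
        return rebalance(node);
    }

    /** Height recomputed from the structure, ignoring the cached field. */
    static int realHeight(Node n) {
        return n == null ? -1 : 1 + Math.max(realHeight(n.left), realHeight(n.right));
    }

    static Node build(List<Integer> values) {
        Node root = null;
        for (int v : values) root = insert(root, v);
        return root;
    }

    static List<Integer> oneTo(int n) {
        List<Integer> values = new ArrayList<>();
        for (int i = 1; i <= n; i++) values.add(i);
        return values;
    }

    public static void main(String[] args) {
        Node sorted = build(oneTo(1000));
        System.out.println("sorted:   cached " + height(sorted) + ", real " + realHeight(sorted));

        List<Integer> shuffled = oneTo(1000);
        Collections.shuffle(shuffled, new Random(42)); // arbitrary seed
        Node random = build(shuffled);
        System.out.println("shuffled: cached " + height(random) + ", real " + realHeight(random));
    }
}
```

    With the sorted input both numbers should be 9, in line with the result reported above. For a shuffled input the exact height depends on the order, but the cached and real values must agree and lie between 9 and 13.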