Search code examples
Tags: c, binary-search-tree

How to delete node from Binary Search Tree when it has 2 children


I am creating my first tree, and I am stuck on how to remove a node that has two children. At the moment I am trying to:

  • find the in-order successor (the leftmost node of the right subtree)
  • replace the node to be deleted with it
  • then point the successor's right pointer at the right child that the deleted node was pointing to

But when I try to print the tree after deleting a node, it gets stuck in a loop, and eventually causes a stack overflow.

/*
 * Removes one node whose index equals i from the tree rooted at *n.
 *
 * n  address of the link (root pointer or a parent's left/right field) that
 *    holds the subtree to search; it is rewritten when that node is removed.
 * i  index to remove; nothing happens if no node carries it.
 *
 * NOTE(review): identifiers beginning with '_' are reserved at file scope in
 * C; consider renaming this function once its callers can be updated.
 */
void _rm_node(Node **n, int i)
{
    if (*n == NULL)
    {
        return;
    }

    if (i < (*n)->index)
    {
        _rm_node(&(*n)->left, i);
    }
    else if (i > (*n)->index)
    {
        _rm_node(&(*n)->right, i);
    }
    else if ((*n)->left == NULL || (*n)->right == NULL)
    {
        // zero or one child: splice the only child (or NULL) into the parent link
        Node *tmp = *n;
        *n = (tmp->left != NULL) ? tmp->left : tmp->right;
        free(tmp);
    }
    else
    {
        // two children: move the in-order successor (leftmost node of the
        // right subtree) into the deleted node's position.
        Node *target = *n;

        // Track the link that points at the successor rather than the node
        // itself, so "successor is target->right" and "successor is deeper"
        // are handled by the same code.
        Node **link = &target->right;
        while ((*link)->left != NULL)
        {
            link = &(*link)->left;
        }
        Node *successor = *link;

        // Detach the successor from its old parent. It has no left child, so
        // the parent simply adopts its right subtree. Without this step the
        // old parent keeps pointing at the successor after the successor is
        // moved above it, which creates a cycle in the tree.
        *link = successor->right;

        successor->left = target->left;
        // Read after the detach above: when the successor was target->right,
        // target->right now holds the successor's own right subtree.
        successor->right = target->right;
        *n = successor;
        free(target);
    }
}
/*
 * Prints every index of the tree rooted at t in pre-order, one
 * "tree <index>" line per node. An empty tree prints nothing.
 */
void PrintTree(Node *t)
{
    // Visit the node, recurse into the left subtree, then continue the
    // same walk down the right child instead of recursing into it.
    while (t != NULL)
    {
        printf("tree %d\n", t->index);
        PrintTree(t->left);
        t = t->right;
    }
}

Solution

    1. You need to update the successor's parent, prev, so that the link that pointed to successor points to successor->right instead. There are two cases:

      (a) prev is the node being deleted, i.e. successor is its right child. Here *n = successor; already links the successor into place, and its right pointer should be left as it is. However, (*n)->right = NULL; is not a no-op: it discards the successor's own right subtree (node 9 in example (a) below).

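      (b) Otherwise you need to update prev's left pointer with prev->left = successor->right, which your code never does. After *n = successor; and (*n)->right = tmp->right;, the successor sits above prev while prev->left still points back at the successor. That creates a cycle, which is why PrintTree recurses forever and overflows the stack. The successor's original right subtree is also lost.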

    2. *n = successor is a little heavy-handed. Since index is the only data a node carries, you only need to copy the successor's index into *n and, in case (a), update its right pointer.

    3. You use index values to distinguish case (a) from case (b), which is only correct if index values are unique. The code below compares node addresses instead, which also works with duplicate index values.

    #include <stdio.h>
    #include <stdlib.h>
    
    // Binary search tree node. InsertNode places indices <= a node's index
    // in its left subtree and larger indices in its right subtree.
    typedef struct Node {
        struct Node *left;   // subtree of indices <= index (as inserted by InsertNode)
        struct Node *right;  // subtree of indices > index
        int index;           // the key the tree is ordered by
    } Node;
    
    // Builds a tree by inserting indices[0..n-1] in order; returns NULL when n == 0.
    Node *BuildTree(size_t n, int indices[n]);
    // Removes one node whose index equals i, if any; *n is rewritten when needed.
    void RemoveNode(Node **n, int i);
    
    // Pre-order dump of the tree rooted at n: one "tree <index>" line per node.
    void PrintTree(Node *n) {
        // Recurse left, then walk down the right child iteratively.
        for (; n != NULL; n = n->right) {
            printf("tree %d\n", n->index);
            PrintTree(n->left);
        }
    }
    
    // Removes one node whose index equals i from the tree rooted at *n.
    // *n is the link (root pointer or a parent's child field) holding the
    // subtree; it is rewritten when that subtree's root is removed.
    void RemoveNode(Node **n, int i) {
        Node *cur = *n;
        if (cur == NULL)
            return;
        if (i != cur->index) {
            // Not here: descend into the side that can contain i.
            RemoveNode(i < cur->index ? &cur->left : &cur->right, i);
            return;
        }
        if (cur->left == NULL || cur->right == NULL) {
            // Zero or one child: splice that child (possibly NULL) into the link.
            *n = (cur->left != NULL) ? cur->left : cur->right;
            free(cur);
            return;
        }
        // Two children: take over the index of the in-order successor (leftmost
        // node of the right subtree), then unlink and free that successor.
        // Following the link that points at it covers both the "successor is
        // cur->right" and the "successor is deeper" cases.
        Node **link = &cur->right;
        while ((*link)->left != NULL)
            link = &(*link)->left;
        Node *succ = *link;
        cur->index = succ->index;
        *link = succ->right;
        free(succ);
    }
    
    // Inserts a new node carrying `index` into the tree rooted at `root` and
    // returns the (possibly new) root. An index equal to an existing node's
    // goes to that node's left subtree, so duplicates are kept.
    // On allocation failure the program is terminated (see below).
    Node *InsertNode(Node *root, int index) {
        Node *n = calloc(1, sizeof *n);  // zeroed, so left/right start out NULL
        if (n == NULL) {
            // Abort-on-OOM policy: the Node * return value cannot report
            // failure, and writing through NULL below would be undefined.
            fprintf(stderr, "InsertNode: out of memory\n");
            exit(EXIT_FAILURE);
        }
        n->index = index;
        // Walk link pointers so the empty-tree and child cases are handled uniformly.
        Node **p = &root;
        while (*p) {
            if (index <= (*p)->index)
                p = &(*p)->left;
            else
                p = &(*p)->right;
        }
        *p = n;
        return root;
    }
    
    // Builds a tree by inserting indices[0..n-1] in array order via InsertNode.
    // With n == 0 the loop never runs and the empty tree (NULL) is returned.
    Node *BuildTree(size_t n, int indices[n]) {
        Node *root = NULL;
        for (size_t k = 0; k < n; ++k) {
            root = InsertNode(root, indices[k]);
        }
        return root;
    }
    
    // Releases every node of the tree rooted at t; children are freed before
    // their parent.
    static void FreeTree(Node *t) {
        if (t == NULL)
            return;
        FreeTree(t->left);
        FreeTree(t->right);
        free(t);
    }

    // Builds a tree from indices[0..n-1], prints it, removes the node with
    // index 6, prints it again, and finally frees the tree so repeated calls
    // do not leak.
    void test(size_t n, int *indices) {
        Node *t = BuildTree(n, indices);
        printf("Before:\n\n");
        PrintTree(t);
        RemoveNode(&t, 6);
        printf("\nAfter:\n\n");
        PrintTree(t);
        FreeTree(t);
    }
    
    // Runs both two-children deletion scenarios, separated by a blank line.
    int main(void) {
        // Case (a): the successor of 6 is its right child 7.
        int case_a[] = {10, 6, 2, 7, 9};
        // Case (b): the successor of 6 is 7, the leftmost node under 6's right child 8.
        int case_b[] = {10, 6, 2, 8, 7, 9};
        test(sizeof case_a / sizeof case_a[0], case_a);
        putchar('\n');
        test(sizeof case_b / sizeof case_b[0], case_b);
        return 0;
    }
    

    And the matching output:

    Before:
    
    tree 10
    tree 6
    tree 2
    tree 7
    tree 9
    
    After:
    
    tree 10
    tree 7
    tree 2
    tree 9
    
    Before:
    
    tree 10
    tree 6
    tree 2
    tree 8
    tree 7
    tree 9
    
    After:
    
    tree 10
    tree 7
    tree 2
    tree 8
    tree 9