Tags: c++, recursion, data-structures, reference, pass-by-reference

Understanding reference arguments in recursion


I have tried implementing a solution to the LeetCode problem "Convert Sorted List to Binary Search Tree": https://leetcode.com/problems/convert-sorted-list-to-binary-search-tree/

I used the recursive approach below. It only works if head is passed by reference (&) in the function argument; without the &, the output is wrong. I can't grasp why a reference to head is needed here, or in what kinds of recursion we should do that. Recursion sometimes confuses me.

int countNodes(ListNode* head) {
    int count = 0;
    ListNode *temp = head;
    
    while (temp != NULL) {
        temp = temp->next;
        count++;
    }
    return count;
}

TreeNode *sortedListUtil(ListNode *&head, int n) { 
    // head must be passed by reference: without &,
    //   the output is wrong
    if (n <= 0)
        return NULL;
    
    TreeNode *left = sortedListUtil(head, n/2);
    
    TreeNode *root = new TreeNode(head->val);
    
    root->left = left;
    head = head->next;
    
    root->right = sortedListUtil(head, n - n/2 -1); 
        // recur for remaining nodes
    
    return root;
}

TreeNode* sortedListToBST(ListNode* head) {
    if (head == NULL)
        return NULL;
    
    int n = countNodes(head);
    return sortedListUtil(head, n);
}

Input linked list: head = [-10,-3,0,5,9]
With & on head in the function argument, the BST is (in LeetCode's level-order notation) [0,-3,9,-10,null,5], which is correct.

Without the &, the tree constructed is:

[-10,-10,-3,-10,null,-3], which is wrong: the root node should be 0.
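To reproduce both results side by side, here is a small self-contained sketch. The `ListNode`/`TreeNode` definitions only mimic the ones LeetCode supplies (an assumption), and `buildByRef`, `buildByValue`, `makeList` and `levelOrder` are hypothetical names; the two build functions are the question's `sortedListUtil` with and without the `&`. There is no `main`, and nodes are never freed, to keep it short.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <string>
#include <vector>

// Assumed LeetCode-style node definitions (LeetCode normally supplies these).
struct ListNode {
    int val;
    ListNode *next;
    explicit ListNode(int x, ListNode *n = nullptr) : val(x), next(n) {}
};

struct TreeNode {
    int val;
    TreeNode *left = nullptr;
    TreeNode *right = nullptr;
    explicit TreeNode(int x) : val(x) {}
};

// The question's version: head is passed by reference, so every advance
// made by a callee is visible to its caller.
TreeNode *buildByRef(ListNode *&head, int n) {
    if (n <= 0) return nullptr;
    TreeNode *left = buildByRef(head, n / 2);
    assert(head != nullptr);  // n > 0 means a node must be available here
    TreeNode *root = new TreeNode(head->val);
    root->left = left;
    head = head->next;
    root->right = buildByRef(head, n - n / 2 - 1);
    return root;
}

// Identical code, but head is passed by value: each call advances only its
// own copy, so the caller never sees what its callees consumed.
TreeNode *buildByValue(ListNode *head, int n) {
    if (n <= 0) return nullptr;
    TreeNode *left = buildByValue(head, n / 2);
    assert(head != nullptr);
    TreeNode *root = new TreeNode(head->val);
    root->left = left;
    head = head->next;
    root->right = buildByValue(head, n - n / 2 - 1);
    return root;
}

// Hypothetical helper: builds a linked list from a vector, front to back.
ListNode *makeList(const std::vector<int> &vals) {
    ListNode *head = nullptr;
    for (auto it = vals.rbegin(); it != vals.rend(); ++it)
        head = new ListNode(*it, head);
    return head;
}

// Hypothetical helper: level-order string like LeetCode prints,
// with trailing "null" entries trimmed.
std::string levelOrder(TreeNode *root) {
    std::vector<std::string> out;
    std::queue<TreeNode *> q;
    q.push(root);
    while (!q.empty()) {
        TreeNode *node = q.front();
        q.pop();
        if (node == nullptr) {
            out.push_back("null");
            continue;
        }
        out.push_back(std::to_string(node->val));
        q.push(node->left);
        q.push(node->right);
    }
    while (!out.empty() && out.back() == "null") out.pop_back();
    std::string s = "[";
    for (std::size_t i = 0; i < out.size(); ++i) s += (i ? "," : "") + out[i];
    return s + "]";
}
```

For example, `ListNode *a = makeList({-10, -3, 0, 5, 9});` followed by `levelOrder(buildByRef(a, 5))` gives `[0,-3,9,-10,null,5]`, while `buildByValue` on a fresh copy of the list gives `[-10,-10,-3,-10,null,-3]`, matching the two outputs above.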


Solution

  • sortedListUtil does two things: it returns the root of the subtree it built (1), and it also advances head (2), so that successive recursive invocations keep moving along the list from one call to the next:

    TreeNode *sortedListUtil(ListNode *&head, int n) { 
        // head must be passed by reference: without &,
        //   the output is wrong
        if (n <= 0)
            return NULL;
        
        TreeNode *left = sortedListUtil(head, n/2);
        
        TreeNode *root = new TreeNode(head->val);  // <<<<------ NB (3)
        
        root->left = left;
        head = head->next;  // <<<<--------------------- NB (2)
        
        root->right = sortedListUtil(head, n - n/2 -1); 
            // recur for remaining nodes
        
        return root;        // <<<<--------------------- NB (1)
    }
    

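    Without the reference, any advance of head made inside a recursive call is lost when that call returns, so the caller keeps reading the same element it started from. That is why the wrong tree contains only -10 and -3, repeated: the later list elements are never reached.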

    This way, each time an element is put into a TreeNode by TreeNode *root = new TreeNode(head->val); (3), head is advanced, so the next element of the list is the one that goes into the next TreeNode to be constructed.

    Since head is the function's parameter, it must be passed by reference (&) so that the change is seen by the caller. Otherwise the parameter is a local copy of the caller's pointer, and advancing it would not be seen by the caller.
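A minimal sketch of that difference, independent of the BST code. The hypothetical helpers below reassign a pointer parameter passed by value versus by reference, and also modify the node it points to; the three-field-free `ListNode` is a simplified assumed definition. Only reassigning the by-value copy is lost; the nodes themselves are shared either way.

```cpp
#include <cassert>

// Minimal singly linked list node (assumed; LeetCode's also has constructors).
struct ListNode {
    int val;
    ListNode *next;
};

// `p` is a copy of the caller's pointer: reassigning it moves only the copy,
// so the caller's pointer is unchanged after the call returns.
void advanceByValue(ListNode *p) {
    assert(p != nullptr);
    p = p->next;
}

// `p` is another name for the caller's pointer: the caller sees the move.
void advanceByRef(ListNode *&p) {
    assert(p != nullptr);
    p = p->next;
}

// Even through a copied pointer, the node it points to is shared with the
// caller, so changing the node itself is visible.
void bumpValue(ListNode *p) {
    assert(p != nullptr);
    p->val += 100;
}
```

With `head` pointing at the first of three nodes, `advanceByValue(head)` leaves `head` where it was, `advanceByRef(head)` moves it to the second node, and `bumpValue(head)` then adds 100 to that second node's `val` even though the pointer was copied.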

    In short: once a list element has been used, the pointer into the list must be advanced, so that the next list element goes into the next tree node.


    Edit: Where does a given sortedListUtil invocation change head, and how many times? Just once! Whatever is done within the "left" invocation(s) advances head past all the elements of the left subtree; then we take one element (the one head points at once the left subtree is filled), advance head one notch accordingly, and let the "right" invocation(s) fill the right subtree.
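One way to watch this happen is to log each value at the moment an invocation takes it. The sketch below is the answer's function with an extra, hypothetical `consumed` vector, using minimal assumed `ListNode`/`TreeNode` definitions; `inorder` is a plain in-order traversal added for comparison.

```cpp
#include <cassert>
#include <vector>

// Minimal assumed node definitions.
struct ListNode {
    int val;
    ListNode *next;
};

struct TreeNode {
    int val;
    TreeNode *left = nullptr;
    TreeNode *right = nullptr;
    explicit TreeNode(int x) : val(x) {}
};

// The answer's algorithm, plus a log of each value as this invocation takes it.
TreeNode *buildTraced(ListNode *&head, int n, std::vector<int> &consumed) {
    if (n <= 0) return nullptr;  // base case: head is not touched
    // The left invocations advance head past the left subtree's n/2 elements.
    TreeNode *left = buildTraced(head, n / 2, consumed);
    assert(head != nullptr);
    TreeNode *root = new TreeNode(head->val);
    consumed.push_back(head->val);
    root->left = left;
    head = head->next;  // this invocation's single advance
    root->right = buildTraced(head, n - n / 2 - 1, consumed);
    return root;
}

// In-order traversal (left subtree, root, right subtree) of the finished tree.
void inorder(const TreeNode *t, std::vector<int> &out) {
    if (t == nullptr) return;
    inorder(t->left, out);
    out.push_back(t->val);
    inorder(t->right, out);
}
```

For the list [-10,-3,0,5,9], `consumed` comes out as -10, -3, 0, 5, 9 (list order), `head` ends as `nullptr`, and the in-order traversal of the tree equals `consumed`. That last point is why the result is a valid BST: each node is created after its whole left subtree and before its right subtree, so a sorted list yields a sorted in-order walk.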

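    Recursion can be reasoned about by induction: assume the function works for the smaller cases (here, the left and right calls, which handle n/2 and n - n/2 - 1 elements), and check the smallest case directly (n <= 0, where we just return NULL and do not touch head). Then check that the small change made by this invocation doesn't break anything (it doesn't: we take one element and advance head exactly one notch). Conclude that it works overall! The precise hypothesis is: sortedListUtil(head, n) builds a BST from the n elements starting at head, and leaves head pointing just past them.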
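That hypothesis can also be written into the function's signature. As an alternative to the reference parameter (a different technique from the one in the question, not a fix to it), the hypothetical `build` below returns the advanced list position together with the subtree, so every caller receives the "new head" from the return value instead of through `&`. The node types are minimal assumed definitions.

```cpp
#include <cassert>
#include <utility>

// Minimal assumed node definitions.
struct ListNode {
    int val;
    ListNode *next;
};

struct TreeNode {
    int val;
    TreeNode *left = nullptr;
    TreeNode *right = nullptr;
    explicit TreeNode(int x) : val(x) {}
};

// Contract (the induction hypothesis): build a height-balanced BST from the
// first n nodes starting at `head`, and return it together with the list node
// just after those n nodes.
std::pair<TreeNode *, ListNode *> build(ListNode *head, int n) {
    if (n <= 0) return {nullptr, head};  // base case: consumes nothing
    // Left subtree consumes n/2 nodes and reports where it stopped.
    std::pair<TreeNode *, ListNode *> leftPart = build(head, n / 2);
    ListNode *mid = leftPart.second;
    assert(mid != nullptr);
    TreeNode *root = new TreeNode(mid->val);  // this call consumes exactly one node
    root->left = leftPart.first;
    // Right subtree consumes the remaining n - n/2 - 1 nodes.
    std::pair<TreeNode *, ListNode *> rightPart = build(mid->next, n - n / 2 - 1);
    root->right = rightPart.first;
    return {root, rightPart.second};
}

int countNodes(const ListNode *head) {
    int count = 0;
    for (; head != nullptr; head = head->next) ++count;
    return count;
}

TreeNode *sortedListToBST(ListNode *head) {
    return build(head, countNodes(head)).first;
}
```

Here the comment above `build` states exactly what the induction relies on: the left call consumes n/2 nodes and reports where it stopped, this call consumes one, and the right call consumes the remaining n - n/2 - 1. For [-10,-3,0,5,9] it builds the same tree as the reference version, with root 0.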