
How can this code solve the Sum Tree problem without executing a return statement in a recursive function?


I am trying to solve GeeksforGeeks problem Sum Tree:

Given a Binary Tree. Return true if, for every node X in the tree other than the leaves, its value is equal to the sum of its left subtree's value and its right subtree's value. Else return false.

An empty tree is also a Sum Tree as the sum of an empty tree can be considered to be 0. A leaf node is also considered a Sum Tree.

Example 1:

Input:

    3
  /   \    
 1     2

Output: 1

Explanation: The sum of left subtree and right subtree is 1 + 2 = 3, which is the value of the root node. Therefore, the given binary tree is a sum tree.

Example 2:

Input:

       10
     /    \
   20      30
  /  \ 
10    10

Output: 0

Explanation: The given tree is not a sum tree. For the root node, sum of elements in left subtree is 40 and sum of elements in right subtree is 30. Root element = 10 which is not equal to 30+40.

My code below passed all major test cases, such as these:

  1. Test-case 1 : 1
  2. Test-case 2 : 62 16 15 N 8 4 7 N 8 4
  3. Test-case 3 : 110 30 30 10 10 20 10

My question

How is it possible that my code works without executing a return statement in the recursive call of the solve function? It looks like there is a default return value for the recursive function solve that I can rely on.

Code

Here is the driver code from GeeksforGeeks (not mine -- I just provide it for complete information):

#include <bits/stdc++.h>
using namespace std;

struct Node
{
    int data;
    struct Node *left;
    struct Node *right;
};
// Utility function to create a new Tree Node
Node* newNode(int val)
{
    Node* temp = new Node;
    temp->data = val;
    temp->left = NULL;
    temp->right = NULL;
    
    return temp;
}
// Function to Build Tree
Node* buildTree(string str)
{   
    // Corner Case
    if(str.length() == 0 || str[0] == 'N')
            return NULL;
    
    // Creating vector of strings from input 
    // string after splitting by space
    vector<string> ip;
    
    istringstream iss(str);
    for(string str; iss >> str; )
        ip.push_back(str);
        
    // Create the root of the tree
    Node* root = newNode(stoi(ip[0]));
        
    // Push the root to the queue
    queue<Node*> queue;
    queue.push(root);
        
    // Starting from the second element
    int i = 1;
    while(!queue.empty() && i < ip.size()) {
            
        // Get and remove the front of the queue
        Node* currNode = queue.front();
        queue.pop();
            
        // Get the current node's value from the string
        string currVal = ip[i];
            
        // If the left child is not null
        if(currVal != "N") {
                
            // Create the left child for the current node
            currNode->left = newNode(stoi(currVal));
                
            // Push it to the queue
            queue.push(currNode->left);
        }
            
        // For the right child
        i++;
        if(i >= ip.size())
            break;
        currVal = ip[i];
            
        // If the right child is not null
        if(currVal != "N") {
                
            // Create the right child for the current node
            currNode->right = newNode(stoi(currVal));
                
            // Push it to the queue
            queue.push(currNode->right);
        }
        i++;
    }
    
    return root;
}
// } Driver Code Ends

// Solution class comes here 
// ... see my code in next code block ...
// 

//{ Driver Code Starts.

int main()
{

    int t;
    scanf("%d ",&t);
    while(t--)
    {
        string s;
        getline(cin,s);
        Node* root = buildTree(s);
        Solution ob;
        cout <<ob.isSumTree(root) << endl;
    }
    return 1;
}
// } Driver Code Ends

My code:

class Solution
{
    public:
    // int returnSum=0;

    // Should return true if tree is Sum Tree, else false
    int solve(Node* root, int& sumofSubtree){
        if(root==NULL){
            return 0;
        }
        int l = solve(root->left, sumofSubtree);
        int r = solve(root->right, sumofSubtree);
        // int sum = l+r;
        sumofSubtree = l+r+root->data; 
        if(sumofSubtree-root->data == root->data){
            return sumofSubtree;
        }
    }
    
    bool isSumTree(Node* root)
    {  //Main Function
        int sumofSubtree = 0;
        solve(root, sumofSubtree);
        cout<<"sumofSubtree"<<" : "<<sumofSubtree<<endl;
        // 
        if(sumofSubtree-root->data==root->data || sumofSubtree == root->data){
            return true;
        }
        else{
            return false;
        }
    }
};

Solution

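  • If control reaches the end of a non-void function without executing a return statement, the behaviour is undefined. The caller then typically reads whatever value happens to be left in the return register, which can coincide with the value you expected. So the code may appear to work for several runs and inputs, but this is not reliable.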

    Your code certainly does not pass for every input. For instance, I used this custom input test case in the GeeksforGeeks user interface:

    3 1 2 4
    

    ...and the test failed.

    The isSumTree function gives two possibilities for the function to return true:

    • sumofSubtree - root->data == root->data or
    • sumofSubtree == root->data

    This makes little sense: there is only one correct value for root->data, not two.

    Secondly, this is the only check that determines the boolean result, whether or not a deeper node has an incorrect value. It assumes that if deeper nodes have incorrect values, sumofSubtree can never match the root data. That assumption is wrong: two wrong values in deeper nodes may compensate each other in the final value of sumofSubtree, which then accidentally matches the root's data again, giving a false positive.
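To make the false positive concrete, here is a small sketch. The names total, rootOnlyCheck, everyNodeCheck and demoCompensatingErrors are made up for this illustration, and the tree is my own construction, not one of the GeeksforGeeks test cases. It assumes the recursion delivered the true sum of the whole tree, which your code does not guarantee.

```cpp
#include <cassert>

struct Node {
    int data;
    Node *left;
    Node *right;
};

// Sum of every value in the subtree rooted at `root` (0 for an empty tree).
int total(const Node *root) {
    return root ? root->data + total(root->left) + total(root->right) : 0;
}

// Root-only test, mirroring the condition in isSumTree and assuming that
// sumofSubtree holds the true sum of the whole tree.
bool rootOnlyCheck(const Node *root) {
    if (!root) return true;
    int sumofSubtree = total(root);
    return sumofSubtree - root->data == root->data || sumofSubtree == root->data;
}

// Reference test: every non-leaf node must equal the sum of its two subtrees.
// Quadratic in the worst case, which is fine for a demonstration.
bool everyNodeCheck(const Node *root) {
    if (!root || (!root->left && !root->right)) return true;
    return root->data == total(root->left) + total(root->right)
        && everyNodeCheck(root->left) && everyNodeCheck(root->right);
}

// Builds the tree 20 / (4 / 1, 2), (6 / 3, 4). A valid sum tree would have
// 3 and 7 in place of 4 and 6; the two errors (+1 and -1) cancel at the root.
void demoCompensatingErrors() {
    Node a{1, nullptr, nullptr}, b{2, nullptr, nullptr};
    Node c{3, nullptr, nullptr}, d{4, nullptr, nullptr};
    Node left{4, &a, &b}, right{6, &c, &d};
    Node root{20, &left, &right};
    assert(rootOnlyCheck(&root));    // accepted: the descendants sum to 20
    assert(!everyNodeCheck(&root));  // rejected: 4 != 1 + 2
}
```

Calling demoCompensatingErrors() from main runs both assertions: the root-only test accepts the tree, while a test that looks at every node rejects it.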

    Of course, these considerations matter little as long as the code has undefined behaviour.
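For comparison, here is a minimal sketch of how the same post-order idea could be written with defined behaviour: the helper returns a value on every path, and a flag records a mismatch at any node, not just at the root. The names subtreeSum, isSumTreeAllPaths and demoSubtreeSum, and the valid flag, are my own choices for this sketch, not part of your code or of the GeeksforGeeks template.

```cpp
#include <cassert>

struct Node {
    int data;
    Node *left;
    Node *right;
};

// Returns the sum of the subtree rooted at `root` on every path.
// Clears `valid` when a non-leaf node differs from the sum of its subtrees.
int subtreeSum(const Node *root, bool &valid) {
    if (!root) return 0;
    int l = subtreeSum(root->left, valid);
    int r = subtreeSum(root->right, valid);
    bool isLeaf = !root->left && !root->right;
    if (!isLeaf && root->data != l + r) valid = false;
    return l + r + root->data;  // always executed, so no undefined behaviour
}

// An empty tree counts as a sum tree, as the problem statement requires.
bool isSumTreeAllPaths(const Node *root) {
    bool valid = true;
    subtreeSum(root, valid);
    return valid;
}

// The custom input "3 1 2 4": node 1 has a left child 4, so 1 != 4.
void demoSubtreeSum() {
    Node four{4, nullptr, nullptr}, two{2, nullptr, nullptr};
    Node one{1, &four, nullptr};
    Node root{3, &one, &two};
    assert(!isSumTreeAllPaths(&root));
}
```

This version still visits every node, so it does not address the early-exit remark further down; it only removes the undefined behaviour and the root-only check.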

    Some other remarks

    • Although sumofSubtree is passed by reference, no caller other than isSumTree ever reads the value stored there by the recursive calls: solve overwrites it immediately. In essence, this by-reference parameter only serves to check the root node, which does not help the readability of the code.

    • All of the tree is traversed, while for some inputs an inspection of the first few nodes could already indicate it is not correct. Take for instance this potentially large tree:

                  100
                /     \
              20       20
             /  \     /  \
            5    5   5    5
           / \  / \ / \  / \
         .. .. .. .. .. .. ..
      

    Without looking at any of the nodes in the deeper levels it is already clear that if those two 20 values correctly match the required sums and are not leaves, then the root should have value 80, not 100. We don't need to look at the deeper values to know this. The only thing that counts here is that these 20 nodes are not leaves and so if we temporarily assume that their subtrees sum up to 20, the two subtrees of the root both sum up to 40.

    Similar reasoning applies when one of the root's children is a leaf (its value then counts once, not twice), and so on.

    Here is code that doesn't look deeper than necessary. Of course, if the tree happens to be correct as a whole, then all nodes need to be traversed:

    class Solution
    {
    private:
        bool isLeaf(Node *root) {
            return root && !root->left && !root->right;
        }
        
        int getSum(Node *root) {
            return !root ? 0 : isLeaf(root) ? root->data : 2 * root->data;
        }
    
    public:
        bool isSumTree(Node* root)
        {
            return !root
                || isLeaf(root)
                || (root->data == getSum(root->left) + getSum(root->right)
                    && isSumTree(root->left) && isSumTree(root->right));
        }
    };
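To see the early exit in action, here is an instrumented copy of the class above. The name CountingSolution, the calls counter and demoEarlyExit are additions for this illustration only. The tree is the 100 / 20, 20 example with its top three levels; the deeper levels are left out because the check never reaches them.

```cpp
#include <cassert>

struct Node {
    int data;
    Node *left;
    Node *right;
};

// Same logic as the Solution class, plus a counter of isSumTree calls.
struct CountingSolution {
    int calls = 0;  // hypothetical instrumentation, not needed for the answer

    static bool isLeaf(const Node *n) {
        return n && !n->left && !n->right;
    }

    // Sum a subtree must have if it is a valid sum tree: a leaf counts once,
    // an inner node counts twice (its own value plus its descendants).
    static int getSum(const Node *n) {
        return !n ? 0 : isLeaf(n) ? n->data : 2 * n->data;
    }

    bool isSumTree(const Node *root) {
        ++calls;
        return !root
            || isLeaf(root)
            || (root->data == getSum(root->left) + getSum(root->right)
                && isSumTree(root->left) && isSumTree(root->right));
    }
};

void demoEarlyExit() {
    Node n1{5, nullptr, nullptr}, n2{5, nullptr, nullptr};
    Node n3{5, nullptr, nullptr}, n4{5, nullptr, nullptr};
    Node left{20, &n1, &n2}, right{20, &n3, &n4};
    Node root{100, &left, &right};

    CountingSolution s;
    assert(!s.isSumTree(&root));  // 100 != 40 + 40
    assert(s.calls == 1);         // rejected at the root, no recursion
}
```

Calling demoEarlyExit() from main confirms that the mismatch is detected with a single call, because the && operator short-circuits before the recursive calls are made.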