Tags: java, c++, algorithm, binary-search

Different outputs for Java and CPP


I was solving the "Nth root of M" problem, and I solved it in Java. Here is the solution:

    public int NthRoot(int n, int m)
    {
        // code here
        int ans = -1;
        if (m == 0)
            return ans;
        if (n == 1 || n == 0)
            return m;
        if (m == 1)
            return 1;
        int low = 1;
        int high = m;
        while (low < high) {
            
            int mid = (low + high) / 2;
            int pow = (int)Math.pow(mid, n);
            if (pow == m) {
                
                ans = mid;
                break;
            }
            if (pow > m) {
                
                high = mid;
            } else {
                
                low = mid + 1;
            }
        }
        
        return ans;
    }

It passed all the test cases. But when I solved it in C++, some test cases failed. Here is the C++ solution:

    int NthRoot(int n, int m)
    {
        // Code here.
        int ans = -1;
        if (m == 0)
            return ans;
        if (n == 1 || n == 0)
            return m;
        if (m == 1)
            return 1;
        int low = 1;
        int high = m;
        while (low < high) {
            
            int mid = (low + high) / 2;
            int po = (int)pow(mid, n);
            if (po == m) {
                
                ans = (int)mid;
                break;
            }
            if (po > m) {
                
                high = mid;
            } else {
                
                low = mid + 1;
            }
        }
        
        return ans;
    } 

One of the test cases it didn't pass is:

6 4096
Java's output is 4 (Expected result)
C++'s output is -1 (Incorrect result)

When I traced it with pen and paper, I got the same result as Java.

But when I used long long int in the C++ code, it worked fine. The size of int is the same in both Java and C++, right? (When I print INT_MAX in C++ and Integer.MAX_VALUE in Java, they output the same value.)


Solution

  • As you have probably guessed, the problem comes from converting a double to an int when the double is larger than the maximum value an int can represent. More specifically, it comes down to the difference in how Java and C++ handle the cast near the start of your while loop: int po = (int)pow(mid, n);.

    For your example input (6 4096), the first pass through that loop has mid = 2048, so pow returns 2048^6 = 2^66, about 7.3787e+19, which is far too big for an int. In Java, when you cast a too-large floating-point value to int, the result is the largest value representable by int (Integer.MAX_VALUE), as specified in this answer, which quotes the Java Language Specification (§5.1.3):

    • The value must be too large (a positive value of large magnitude or positive infinity), and the result of the first step is the largest representable value of type int or long.

    However, in C++, when the truncated value cannot be represented in the destination type (here, because it exceeds INT_MAX), the behaviour is undefined, according to this draft of the C++ Standard:

    7.10 Floating-integral conversions      [conv.fpint]

    1    A prvalue of a floating-point type can be converted to a prvalue of an integer type. The conversion truncates; that is, the fractional part is discarded. The behavior is undefined if the truncated value cannot be represented in the destination type.
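
    Putting the two rules side by side: a C++ conversion that behaves like Java's has to do its range checks before the cast, because in C++ the out-of-range cast itself is what is undefined. The minimal sketch below assumes Java's int rules (NaN gives 0, too-large values give INT_MAX, too-small values give INT_MIN, everything else truncates toward zero); the names java_style_d2i and check_java_style_d2i are hypothetical, chosen for illustration, and are not standard library functions.

```cpp
#include <cassert>
#include <climits>
#include <cmath>

// Hypothetical helper: double -> int using Java's narrowing rules (JLS 5.1.3).
// All range checks happen BEFORE the cast, so the cast never sees a value
// that int cannot represent (which would be undefined behaviour in C++).
int java_style_d2i(double d) {
    if (std::isnan(d)) return 0;                            // NaN -> 0
    if (d >= static_cast<double>(INT_MAX)) return INT_MAX;  // too large (or +inf)
    if (d <= static_cast<double>(INT_MIN)) return INT_MIN;  // too small (or -inf)
    return static_cast<int>(d);                             // in range: truncates toward zero
}

// Example checks for the helper above.
void check_java_style_d2i() {
    assert(java_style_d2i(std::pow(2048.0, 6)) == INT_MAX);  // about 7.3787e+19
    assert(java_style_d2i(4096.9) == 4096);
    assert(java_style_d2i(-3.9) == -3);
    assert(java_style_d2i(-INFINITY) == INT_MIN);
    assert(java_style_d2i(NAN) == 0);
}
```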

    Although the behaviour is formally undefined, in practice compilers targeting x86/x64 typically produce INT_MIN (-2147483648) for an out-of-range conversion, because that is the value the processor's truncating conversion instruction returns. This is what MSVC in Visual Studio does here: po becomes -2147483648, which is less than m, so the else block runs. It keeps running on every later iteration (every mid from 2049 upward also overflows) until the while loop reaches its terminating condition, at which point ans has never been assigned anything except its initial -1.
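
    Because the behaviour is undefined, it can't be demonstrated portably with the real cast. The sketch below instead substitutes an explicit value whenever pow returns something larger than INT_MAX, so the two outcomes can be compared deterministically. The function nth_root_with_overflow_value and its overflow_value parameter are hypothetical, for illustration only; apart from that substitution, the search mirrors the question's loop.

```cpp
#include <cassert>
#include <climits>
#include <cmath>

// Hypothetical, for illustration: the question's binary search, except that
// when pow() exceeds INT_MAX the explicit overflow_value is used instead of
// the (undefined) cast. INT_MIN mimics the MSVC result; INT_MAX mimics Java.
int nth_root_with_overflow_value(int n, int m, int overflow_value) {
    int ans = -1;
    if (m == 0) return ans;
    if (n == 1 || n == 0) return m;
    if (m == 1) return 1;
    int low = 1;
    int high = m;
    while (low < high) {
        int mid = (low + high) / 2;
        double fpo = std::pow(mid, n);
        int po = (fpo > INT_MAX) ? overflow_value : static_cast<int>(fpo);
        if (po == m) {
            ans = mid;
            break;
        }
        if (po > m)
            high = mid;
        else
            low = mid + 1;
    }
    return ans;
}

// Compares the two overflow behaviours on the failing test case (6 4096).
void check_overflow_behaviours() {
    assert(nth_root_with_overflow_value(6, 4096, INT_MIN) == -1);  // like MSVC
    assert(nth_root_with_overflow_value(6, 4096, INT_MAX) == 4);   // like Java
}
```

    With INT_MIN the search only ever moves low upward (2049, 3073, ...) and returns -1; with INT_MAX it halves high repeatedly (2048, 1024, ..., 8) until mid reaches 4, matching Java.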

    You can fix the problem by checking the double returned by the pow call and using INT_MAX for po when that value is too large, which emulates the Java behaviour (INT_MAX is defined in <climits>):

        while (low < high) {
            int mid = (low + high) / 2;
            double fpo = pow(mid, n);
            // Range-check before converting: the out-of-range cast itself is undefined
            int po = (fpo > INT_MAX) ? INT_MAX : (int)fpo; // Emulate Java for overflow
            if (po == m) {
                ans = (int)mid;
                break;
            }
            if (po > m) {
                high = mid;
            }
            else {
                low = mid + 1;
            }
        }
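
    Emulating Java still leaves the search dependent on floating-point pow. An alternative, not part of the fix above, is to avoid doubles altogether: compute mid^n with 64-bit integer multiplication and stop as soon as the product exceeds m, so nothing can overflow. This is a sketch assuming n >= 0 and m >= 0 as in the problem; the names compare_power, NthRootExact and check_nth_root_exact are hypothetical, and the early returns simply keep the question's conventions for m == 0, n <= 1 and m == 1.

```cpp
#include <cassert>

// Hypothetical helper: compares mid^n with m using only integer arithmetic.
// Returns 1 if mid^n > m, 0 if mid^n == m, -1 if mid^n < m.
// Before each multiply, product <= m <= INT_MAX and mid <= INT_MAX,
// so product * mid < 2^62 and the 64-bit multiplication cannot overflow.
int compare_power(long long mid, int n, long long m) {
    long long product = 1;
    for (int i = 0; i < n; ++i) {
        product *= mid;
        if (product > m) return 1;  // stop early: already too big
    }
    return product == m ? 0 : -1;
}

// Hypothetical integer-only variant of NthRoot (same early returns as the question).
int NthRootExact(int n, int m) {
    if (m == 0) return -1;
    if (n == 1 || n == 0) return m;
    if (m == 1) return 1;
    int low = 1;
    int high = m;
    while (low <= high) {
        int mid = low + (high - low) / 2;  // avoids overflow of low + high near INT_MAX
        int cmp = compare_power(mid, n, m);
        if (cmp == 0) return mid;
        if (cmp > 0)
            high = mid - 1;
        else
            low = mid + 1;
    }
    return -1;  // m is not a perfect n-th power
}

// Example checks, including the failing test case from the question.
void check_nth_root_exact() {
    assert(NthRootExact(6, 4096) == 4);
    assert(NthRootExact(2, 9) == 3);
    assert(NthRootExact(3, 9) == -1);
}
```

    Comparing against m after every multiplication keeps the check exact: no floating-point rounding is involved, and no intermediate value needs more than 64 bits.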