Unsigned Long Int overflow when calculating pow


I am trying to make a function that quickly calculates x^y mod z. It works well when calculating something like 2^63 mod 3, but at 2^64 mod 3 and higher exponents it just returns 0.

I am suspecting an overflow somewhere, but I can't pin it down. I have tried explicit casts at the places where the calculations (* and mod) are made, and I have also made my storage variables (resPow, curPow) unsigned long long int (as suggested here), but that didn't help much.

typedef unsigned long int lint;

lint fastpow(lint nBase, lint nExp, lint nMod) { 
    int lastTrueBit = 0;
    unsigned long long int resPow = 1ULL;

    unsigned long long int curPow = nBase;
    for (int i = 0; i < 32; i++) {
        int currentBit = getBit(nExp, i);

        if (currentBit == 1) {
            for (lint j = 0; j < i - lastTrueBit; j++) {
                curPow = curPow * curPow;
            }
            resPow =resPow * curPow;
            lastTrueBit = i;
        }
    }
    return resPow % nMod;
}

Solution

  • I am suspecting an overflow somewhere,

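    Yes, both curPow * curPow and resPow * curPow may mathematically overflow: the true product can exceed the range of unsigned long long, and unsigned arithmetic then silently wraps around. With a 64-bit unsigned long long, squaring 2 six times (as happens for exponent 64) gives 2^64, which wraps to 0, hence the 0 result.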

    The usual way to contain overflow here is to perform mod on intermediate products.

            // curPow = curPow * curPow;
            curPow = (curPow * curPow) % nMod;
        // resPow =resPow * curPow;
        resPow = (resPow * curPow) % nMod;
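
For context, here is a minimal sketch of how the two changed lines could fit into a complete function. It is not the asker's exact loop: it walks the exponent with a shift instead of getBit(), and the name fastpow_mod is made up for this illustration. It assumes nMod is nonzero and small enough that (nMod - 1) * (nMod - 1) fits in unsigned long long.

```c
#include <assert.h>

typedef unsigned long int lint;

/* Sketch of x^y mod z with a reduction after every multiplication.
 * Assumption: (nMod - 1) * (nMod - 1) fits in unsigned long long,
 * e.g. nMod <= 2^32 when unsigned long long has 64 bits.
 * fastpow_mod is a hypothetical name, not from the question. */
lint fastpow_mod(lint nBase, lint nExp, lint nMod) {
    assert(nMod != 0);                         /* x % 0 is undefined */
    unsigned long long resPow = 1U % nMod;     /* 0 when nMod == 1 */
    unsigned long long curPow = nBase % nMod;  /* reduce the base first */

    while (nExp > 0) {
        if (nExp & 1U) {
            /* Current exponent bit is set: fold this power into the result. */
            resPow = (resPow * curPow) % nMod;
        }
        /* Square for the next bit, keeping the value below nMod. */
        curPow = (curPow * curPow) % nMod;
        nExp >>= 1;
    }
    return (lint)resPow;
}
```

With this sketch, fastpow_mod(2, 63, 3) evaluates to 2 and fastpow_mod(2, 64, 3) to 1, because every intermediate value stays below nMod.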
    

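    This is sufficient when nMod < ULLONG_MAX/(nMod - 1), that is, when nMod needs no more than about half the bits of unsigned long long (for example, nMod up to 2^32 with a 64-bit unsigned long long), provided curPow also starts out reduced as nBase % nMod. Otherwise more extreme measures are needed, as in: Modular exponentiation without range restriction.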
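
One possible "more extreme measure" is to replace each multiplication with a shift-and-add loop that never forms a value of nMod or larger. The following is a rough sketch of that idea, not necessarily the approach taken in the linked answer; the names mulmod and powmod are assumptions made for this example.

```c
#include <assert.h>

/* (a * b) % m for any nonzero m, without forming the full product.
 * Hypothetical helper: builds the product by doubling and adding,
 * reducing at every step so no intermediate reaches m. */
unsigned long long mulmod(unsigned long long a, unsigned long long b,
                          unsigned long long m) {
    assert(m != 0);
    unsigned long long result = 0;
    a %= m;
    b %= m;
    while (b > 0) {
        if (b & 1U) {
            /* result = (result + a) % m, written to avoid wraparound */
            if (result >= m - a) result -= m - a;
            else result += a;
        }
        /* a = (2 * a) % m, written to avoid wraparound */
        if (a >= m - a) a -= m - a;
        else a += a;
        b >>= 1;
    }
    return result;
}

/* base^exp % m using mulmod for every product. */
unsigned long long powmod(unsigned long long base, unsigned long long exp,
                          unsigned long long m) {
    unsigned long long result = 1U % m;
    base %= m;
    while (exp > 0) {
        if (exp & 1U) result = mulmod(result, base, m);
        base = mulmod(base, base, m);
        exp >>= 1;
    }
    return result;
}
```

The trade-off is speed: each mulmod call loops once per bit of b, so this is noticeably slower than a single multiply followed by %.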


    Minor stuff

    for (int i = 0; i < 32; i++) assumes lint/unsigned long is 32 bits. Portable code would avoid that magic number: unsigned long is 64 bits on various platforms, and there the loop ignores exponent bits 32 and up.
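
One way to drop the magic number is to derive the bit count from the type itself. The sketch below assumes unsigned long has no padding bits (true on mainstream platforms). The question did not show getBit(), so the version here is a hypothetical stand-in, and LINT_BITS and bitLength are names invented for this example.

```c
#include <assert.h>
#include <limits.h>

typedef unsigned long int lint;

/* Width of lint in bits, assuming the type has no padding bits. */
#define LINT_BITS ((int)(sizeof(lint) * CHAR_BIT))

/* Hypothetical stand-in for the question's getBit(). */
int getBit(lint value, int index) {
    assert(index >= 0 && index < LINT_BITS);  /* larger shifts are undefined */
    return (int)((value >> index) & 1U);
}

/* Position of the highest set bit plus one (0 for value == 0).
 * Shows a loop bounded by the real width of lint instead of 32. */
int bitLength(lint value) {
    int length = 0;
    for (int i = 0; i < LINT_BITS; i++) {
        if (getBit(value, i) == 1) {
            length = i + 1;
        }
    }
    return length;
}
```

The same LINT_BITS bound could replace the 32 in the question's loop, or the loop could simply stop at bitLength(nExp).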

    The LL suffix is not needed here: 1U is converted to unsigned long long when it initializes resPow. U remains useful to quiet various compiler warnings.

    // unsigned long long int resPow = 1ULL;
    unsigned long long int resPow = 1U;