Tags: c++, algorithm, primes, exponentiation

Overflow possibilities in modular exponentiation by squaring


I am trying to implement a primality test based on Fermat's little theorem. Here's the code I have written:

lld expo(lld n, lld p) //2^p mod n
{
    if(p==0)
        return 1;
    lld exp=expo(n,p/2);
    if(p%2==0)
        return (exp*exp)%n;
    else
        return (((exp*exp)%n)*2)%n;
}

bool ifPseudoPrime(lld n)
{
    if(expo(n,n)==2)
        return true;
    else
        return false;
}

NOTE: I fixed the base a (where a <= n-1) at 2.

Now, the number n can be as large as 10^18, so the variable exp can reach values near 10^18. The expression (exp*exp) can then reach about 10^36, which overflows. How do I avoid this?

I tested this and it ran fine for n up to 10^9. I am using C++.


Solution

  • If the modulus is close to the limit of the largest integer type you can use, things get somewhat complicated. If you can't use a library that implements big-integer arithmetic, you can roll a modular multiplication yourself by splitting the factors into low-order and high-order parts.

    If the modulus m is so large that 2*(m-1) overflows, things get really fussy, but if 2*(m-1) doesn't overflow, it's bearable.

    Let us suppose you have and use a 64-bit unsigned integer type.

    You can calculate the modular product by splitting each factor into its low and high 32 bits. The product then splits into

    a = a1 + (a2 << 32)    // 0 <= a1, a2 < (1 << 32)
    b = b1 + (b2 << 32)    // 0 <= b1, b2 < (1 << 32)
    a*b = a1*b1 + (a1*b2 << 32) + (a2*b1 << 32) + (a2*b2 << 64)
    

    To calculate a*b (mod m) with m <= (1 << 63), reduce each of the four partial products modulo m (each is a product of two 32-bit values, so it fits in 64 bits),

    p1 = (a1*b1) % m;
    p2 = (a1*b2) % m;
    p3 = (a2*b1) % m;
    p4 = (a2*b2) % m;
    

    and the simplest way to incorporate the shifts is

    for(i = 0; i < 32; ++i) {
        p2 *= 2;
        if (p2 >= m) p2 -= m;
    }
    

    the same for p3 and with 64 iterations for p4. Then

    s = p1+p2;
    if (s >= m) s -= m;
    s += p3;
    if (s >= m) s -= m;
    s += p4;
    if (s >= m) s -= m;
    return s;
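
The steps so far can be assembled into one function. The following is only a sketch under the assumptions stated above (an unsigned 64-bit type and 1 <= m <= 2^63); the names `add_mod`, `shift_mod` and `mul_mod_simple` are made up for illustration and do not appear in the original code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: (a + b) % m for a, b < m <= 2^63.
// The sum is below 2^64, so it cannot overflow.
std::uint64_t add_mod(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
    std::uint64_t s = a + b;
    if (s >= m) s -= m;
    return s;
}

// Hypothetical helper: (x << times) % m by repeated doubling.
// Requires x < m <= 2^63, so 2*x always fits in 64 bits.
std::uint64_t shift_mod(std::uint64_t x, int times, std::uint64_t m) {
    for (int i = 0; i < times; ++i) {
        x *= 2;
        if (x >= m) x -= m;
    }
    return x;
}

// (a * b) % m without overflow, for 1 <= m <= 2^63.
std::uint64_t mul_mod_simple(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
    assert(m != 0 && m <= (std::uint64_t{1} << 63));
    const std::uint64_t mask = 0xFFFFFFFFu;
    // Split both factors into low (a1, b1) and high (a2, b2) 32-bit halves.
    std::uint64_t a1 = a & mask, a2 = a >> 32;
    std::uint64_t b1 = b & mask, b2 = b >> 32;
    // Each partial product is below 2^64; reduce it, then apply its shift.
    std::uint64_t p1 = (a1 * b1) % m;
    std::uint64_t p2 = shift_mod((a1 * b2) % m, 32, m);
    std::uint64_t p3 = shift_mod((a2 * b1) % m, 32, m);
    std::uint64_t p4 = shift_mod((a2 * b2) % m, 64, m);
    return add_mod(add_mod(add_mod(p1, p2, m), p3, m), p4, m);
}
```

The factors a and b do not have to be reduced modulo m beforehand, since every partial product is reduced before it is shifted or added.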
    

    This approach is not very fast, but for the few multiplications needed here, it may be fast enough. A small speedup can be obtained by reducing the number of shifts: first calculate (p4 << 32) % m,

    for(i = 0; i < 32; ++i) {
        p4 *= 2;
        if (p4 >= m) p4 -= m;
    }
    

    then p2, p3 and the current value of p4 all need to be multiplied by 2^32 modulo m,

    p4 += p3;
    if (p4 >= m) p4 -= m;
    p4 += p2;
    if (p4 >= m) p4 -= m;
    for(i = 0; i < 32; ++i) {
        p4 *= 2;
        if (p4 >= m) p4 -= m;
    }
    s = p4+p1;
    if (s >= m) s -= m;
    return s;
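
To connect this back to the question, here is a sketch of how the reduced-shift multiplication could be plugged into exponentiation by squaring and the base-2 Fermat check. It assumes an unsigned 64-bit type and 1 <= m <= 2^63 (which covers n up to 10^18); the names `mul_mod`, `pow_mod` and `is_fermat_probable_prime_base2` are hypothetical, and the iterative loop replaces the recursion from the question.

```cpp
#include <cassert>
#include <cstdint>

// (a * b) % m for 1 <= m <= 2^63, using two 32-step doubling loops.
std::uint64_t mul_mod(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
    assert(m != 0 && m <= (std::uint64_t{1} << 63));
    const std::uint64_t mask = 0xFFFFFFFFu;
    std::uint64_t a1 = a & mask, a2 = a >> 32;   // low and high halves of a
    std::uint64_t b1 = b & mask, b2 = b >> 32;   // low and high halves of b
    std::uint64_t p1 = (a1 * b1) % m;
    std::uint64_t p2 = (a1 * b2) % m;
    std::uint64_t p3 = (a2 * b1) % m;
    std::uint64_t p4 = (a2 * b2) % m;
    // First pass: p4 becomes (a2*b2 << 32) % m.
    for (int i = 0; i < 32; ++i) {
        p4 *= 2;
        if (p4 >= m) p4 -= m;
    }
    // Fold in p3 and p2, which need the same remaining factor 2^32.
    p4 += p3;
    if (p4 >= m) p4 -= m;
    p4 += p2;
    if (p4 >= m) p4 -= m;
    // Second pass: multiply the combined value by 2^32 modulo m.
    for (int i = 0; i < 32; ++i) {
        p4 *= 2;
        if (p4 >= m) p4 -= m;
    }
    std::uint64_t s = p4 + p1;
    if (s >= m) s -= m;
    return s;
}

// base^e % m by repeated squaring; every product goes through mul_mod.
std::uint64_t pow_mod(std::uint64_t base, std::uint64_t e, std::uint64_t m) {
    std::uint64_t result = 1 % m;   // 1 % m handles m == 1
    base %= m;
    while (e > 0) {
        if (e & 1) result = mul_mod(result, base, m);
        base = mul_mod(base, base, m);
        e >>= 1;
    }
    return result;
}

// Base-2 Fermat check as in the question: 2^n mod n must equal 2 mod n.
// Composite base-2 pseudoprimes such as 341 also pass this check.
bool is_fermat_probable_prime_base2(std::uint64_t n) {
    return n >= 2 && pow_mod(2, n, n) == 2 % n;
}
```

Comparing against `2 % n` rather than the literal 2 lets n = 2 pass, which the check `expo(n,n)==2` would reject.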