I am looking to implement Fermat's little theorem for primality testing. Here's the code I have written:
lld expo(lld n, lld p) // 2^p mod n
{
    if (p == 0)
        return 1;
    lld exp = expo(n, p / 2);
    if (p % 2 == 0)
        return (exp * exp) % n;
    else
        return (((exp * exp) % n) * 2) % n;
}
bool ifPseudoPrime(lld n)
{
    if (expo(n, n) == 2)
        return true;
    else
        return false;
}
NOTE: I took the value of a (<= n-1) as 2.
Now, the number n can go as large as 10^18. This means that the variable exp can reach values near 10^18, which further implies that the expression (exp*exp) can reach as high as 10^36, hence causing overflow. How do I avoid this?
I tested this and it ran fine up to 10^9. I am using C++.
If the modulus is close to the limit of the largest integer type you can use, things get somewhat complicated. If you can't use a library that implements big-integer arithmetic, you can roll a modular multiplication yourself by splitting the factors into low-order and high-order parts.
If the modulus m is so large that 2*(m-1) overflows, things get really fussy, but if 2*(m-1) doesn't overflow, it's bearable.
Let us suppose you have and use a 64-bit unsigned integer type.
You can calculate the modular product by splitting the factors into their low and high 32 bits; the product then splits into
a = a1 + (a2 << 32) // 0 <= a1, a2 < (1 << 32)
b = b1 + (b2 << 32) // 0 <= b1, b2 < (1 << 32)
a*b = a1*b1 + (a1*b2 << 32) + (a2*b1 << 32) + (a2*b2 << 64)
To calculate a*b (mod m) with m <= (1 << 63), reduce each of the four products modulo m,
p1 = (a1*b1) % m;
p2 = (a1*b2) % m;
p3 = (a2*b1) % m;
p4 = (a2*b2) % m;
and the simplest way to incorporate the shifts is
for (i = 0; i < 32; ++i) {
    p2 *= 2;
    if (p2 >= m) p2 -= m;
}
the same for p3 and with 64 iterations for p4. Then
s = p1+p2;
if (s >= m) s -= m;
s += p3;
if (s >= m) s -= m;
s += p4;
if (s >= m) s -= m;
return s;
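Put together, the steps above might look like the following sketch. The names mulmod_split, shift_mod and add_mod are made up for this illustration, std::uint64_t is assumed as the 64-bit unsigned type, and the assert only encodes the m <= (1 << 63) precondition stated above.

```cpp
#include <cassert>
#include <cstdint>

// Doubles x modulo m `times` times, i.e. computes (x * 2^times) % m.
// Needs x < m <= 2^63 so that 2*x never overflows 64 bits.
static std::uint64_t shift_mod(std::uint64_t x, int times, std::uint64_t m) {
    for (int i = 0; i < times; ++i) {
        x *= 2;
        if (x >= m) x -= m;
    }
    return x;
}

// Adds two residues x, y < m; the sum is below 2*m <= 2^64, so it fits.
static std::uint64_t add_mod(std::uint64_t x, std::uint64_t y, std::uint64_t m) {
    std::uint64_t s = x + y;
    if (s >= m) s -= m;
    return s;
}

// (a * b) % m via the 32-bit split described above (hypothetical helper).
std::uint64_t mulmod_split(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
    assert(m != 0 && m <= (std::uint64_t{1} << 63));
    const std::uint64_t mask = 0xFFFFFFFFu;
    std::uint64_t a1 = a & mask, a2 = a >> 32;  // a = a1 + (a2 << 32)
    std::uint64_t b1 = b & mask, b2 = b >> 32;  // b = b1 + (b2 << 32)

    // Each partial product of two 32-bit halves fits in 64 bits.
    std::uint64_t p1 = (a1 * b1) % m;
    std::uint64_t p2 = shift_mod((a1 * b2) % m, 32, m);
    std::uint64_t p3 = shift_mod((a2 * b1) % m, 32, m);
    std::uint64_t p4 = shift_mod((a2 * b2) % m, 64, m);

    return add_mod(add_mod(add_mod(p1, p2, m), p3, m), p4, m);
}
```

For example, mulmod_split(999999999999999999ULL, 999999999999999999ULL, 1000000000000000000ULL) gives 1, since both factors are congruent to -1 modulo 10^18.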
This approach is not very fast, but for the few multiplications needed here, it may be fast enough. A small speedup can be obtained by reducing the number of shifts; first calculate (p4 << 32) % m,
for (i = 0; i < 32; ++i) {
    p4 *= 2;
    if (p4 >= m) p4 -= m;
}
then all of p2, p3 and the current value of p4 need to be multiplied by 2^32 modulo m,
p4 += p3;
if (p4 >= m) p4 -= m;
p4 += p2;
if (p4 >= m) p4 -= m;
for (i = 0; i < 32; ++i) {
    p4 *= 2;
    if (p4 >= m) p4 -= m;
}
s = p4+p1;
if (s >= m) s -= m;
return s;
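To tie this back to the question, here is a rough sketch of the reduced-shift multiplication plugged into a base-2 Fermat check. The names mulmod, powmod and isFermatProbablePrime2 are invented for this example and std::uint64_t is assumed as the 64-bit unsigned type. Note two deliberate differences from the code in the question: the exponentiation is iterative rather than recursive, and the test is written as 2^(n-1) ≡ 1 (mod n) instead of 2^n ≡ 2 (mod n), which is equivalent for odd n and rejects every even n > 2.

```cpp
#include <cassert>
#include <cstdint>

// (a * b) % m for m <= 2^63, using the reduced-shift variant above.
std::uint64_t mulmod(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
    assert(m != 0 && m <= (std::uint64_t{1} << 63));
    std::uint64_t a1 = a & 0xFFFFFFFFu, a2 = a >> 32;
    std::uint64_t b1 = b & 0xFFFFFFFFu, b2 = b >> 32;
    std::uint64_t p1 = (a1 * b1) % m, p2 = (a1 * b2) % m;
    std::uint64_t p3 = (a2 * b1) % m, p4 = (a2 * b2) % m;

    // Multiplies a residue x < m by 2^32 modulo m, one doubling at a time.
    auto shift32 = [m](std::uint64_t x) {
        for (int i = 0; i < 32; ++i) {
            x *= 2;
            if (x >= m) x -= m;
        }
        return x;
    };

    p4 = shift32(p4);                 // p4 * 2^32
    p4 += p3; if (p4 >= m) p4 -= m;   // + p3
    p4 += p2; if (p4 >= m) p4 -= m;   // + p2
    p4 = shift32(p4);                 // (p4 * 2^32 + p3 + p2) * 2^32
    std::uint64_t s = p4 + p1;
    if (s >= m) s -= m;
    return s;
}

// base^exp % m by square-and-multiply, built on mulmod.
std::uint64_t powmod(std::uint64_t base, std::uint64_t exp, std::uint64_t m) {
    std::uint64_t result = 1 % m;
    base %= m;
    while (exp > 0) {
        if (exp & 1) result = mulmod(result, base, m);
        base = mulmod(base, base, m);
        exp >>= 1;
    }
    return result;
}

// Base-2 Fermat test: true for primes and for base-2 pseudoprimes.
bool isFermatProbablePrime2(std::uint64_t n) {
    if (n < 2) return false;
    if (n == 2) return true;
    return powmod(2, n - 1, n) == 1;
}
```

Keep in mind that a Fermat test can only report "probably prime": 341 = 11 * 31 passes the base-2 test although it is composite.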