I am learning number theory, and I want to write a program that performs the Fermat primality test.
First, I wrote a modular exponentiation (square-and-multiply) function:
```python
# modular_square.py
def modular_square(a, n, m):
    res = 1
    exp = n
    b = a
    while exp != 0:
        if exp % 2 == 1:
            res *= b
            res %= m
        b *= b
        exp >>= 1
    return res

def main():
    a = [12996, 312, 501, 468, 163]
    n = [227, 13, 13, 237, 237]
    m = [37909, 667, 667, 667, 667]
    res = [7775, 468, 163, 312, 501]
    # test modular_square()
    print("===test modular_square()===")
    for i, r in enumerate(res):
        if modular_square(a[i], n[i], m[i]) != r:
            print("modular_square() failed...")
        else:
            print("modular_square({},{},{})={}".format(a[i], n[i], m[i], r))

if __name__ == "__main__":
    main()
```
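For reference, the loop above is the right-to-left binary (square-and-multiply) method. Writing the exponent in binary, `res` collects one factor for each 1-bit, and after i iterations the variable `b` holds a raised to the power 2^i. For example, 13 = 1101 in binary, so a^13 = a^8 · a^4 · a^1:

```latex
n = \sum_{i=0}^{k} n_i \, 2^i,\quad n_i \in \{0,1\}
\qquad\Longrightarrow\qquad
a^{n} \equiv \prod_{i \,:\, n_i = 1} a^{2^i} \pmod{m},
\qquad b_i = a^{2^i}
```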
Then I wrote a Fermat primality test based on that function:
```python
# prime_test_fermat.py
import modular_square
import random

def Fermat_base(b, n):
    res = modular_square.modular_square(b, n-1, n)
    if res == 1:
        return True
    else:
        return False

def Fermat_test(n, times):
    for i in range(times):
        b = random.randint(2, n-1)
        if Fermat_base(b, n) == False:
            return False
    return True

def main():
    b = [8, 2]
    n = [63, 63]
    res = [True, False]
    # test Fermat_base()
    print("===test Fermat_base()===")
    for i, r in enumerate(res):
        if Fermat_base(b[i], n[i]) != res[i]:
            print("Fermat_base() failed...")
        else:
            print("Fermat_base({},{})={}".format(b[i], n[i], res[i]))
    n = [923861,
         1056420454404911
         ]
    times = [2, 2]
    res = [True, True]
    # test Fermat_test()
    print("==test Fermat_test()===")
    for i, r in enumerate(res):
        if Fermat_test(n[i], times[i]) != res[i]:
            print("Fermat_test() failed...")
        else:
            print("Fermat_test({},{})={}".format(n[i], times[i], res[i]))

if __name__ == '__main__':
    main()
```
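For comparison, here is a minimal sketch of the same two functions written with Python's built-in three-argument `pow(base, exp, mod)`, which performs modular exponentiation and reduces modulo `mod` at every step. The lowercase names and the `n < 3` guard are my own additions, not the asker's code (the original `random.randint(2, n-1)` raises `ValueError` when `n` is 2).

```python
import random

def fermat_base(b, n):
    # pow(b, e, n) computes b**e % n without ever building b**e in full
    return pow(b, n - 1, n) == 1

def fermat_test(n, times):
    # Added guard (assumption): randint(2, n - 1) needs n >= 3
    if n < 3:
        return n == 2
    for _ in range(times):
        b = random.randint(2, n - 1)
        if not fermat_base(b, n):
            return False  # b is a Fermat witness, so n is definitely composite
    return True  # every sampled base passed, so n is probably prime

print(fermat_base(8, 63), fermat_base(2, 63))  # True False
print(fermat_test(923861, 2), fermat_test(1056420454404911, 2))
```

Both calls on the second line return almost immediately, because the intermediate values never exceed `n`.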
When I run prime_test_fermat.py, the program never finishes. Is this caused by the Fermat primality test itself, or is there a bug in my code?
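The problem is in your modular exponentiation algorithm: the modulus is applied to `res` but not to `b`. Since `b` is squared in every iteration, the number of digits it has roughly doubles each time. For 923861, n - 1 is about 20 bits long, so `b` grows to millions of digits; for 1056420454404911, n - 1 is about 50 bits long, and `b` would need far more memory than any computer has. That is why the program never seems to stop.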
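To see the growth concretely, the sketch below repeats only the squaring step for the base 923860 (a possible random base when n = 923861), once without reduction as in the original loop and once with `% m`. The helper names are hypothetical and exist only for this illustration.

```python
def bit_lengths_without_reduction(a, iterations):
    """Square b repeatedly like the original loop (no % m) and record its size in bits."""
    b = a
    sizes = []
    for _ in range(iterations):
        b *= b  # no reduction, so the bit length roughly doubles every time
        sizes.append(b.bit_length())
    return sizes

def bit_lengths_with_reduction(a, m, iterations):
    """Same squaring, but reduced modulo m after every step."""
    b = a % m
    sizes = []
    for _ in range(iterations):
        b = (b * b) % m  # b always stays below m
        sizes.append(b.bit_length())
    return sizes

print(bit_lengths_without_reduction(923860, 16)[-1])        # about 1.3 million bits
print(max(bit_lengths_with_reduction(923860, 923861, 16)))  # at most 20 bits
```

After only 16 of the 20 squarings, the unreduced value is already about 1.3 million bits long, while the reduced one never exceeds the 20 bits of the modulus.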
To solve this, you have to apply the modulus to `b` as well. Replace `b *= b` with:

```python
b *= b
b %= m
```
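This change does not alter the result because reduction modulo m is compatible with multiplication: reducing a factor before multiplying leaves the remainder of the product unchanged. Applied to `b * b` (and to `res * b`), it means `b` can always be kept in the range 0 to m - 1:

```latex
(x \cdot y) \bmod m = \bigl((x \bmod m)\cdot(y \bmod m)\bigr) \bmod m
```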
As an additional optimization, you can also apply the modulus when you initialize `b`, by replacing `b = a` with:

```python
b = a
b %= m
```
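Putting both changes together, a corrected `modular_square` could look like the sketch below, checked against Python's built-in `pow(a, n, m)` on random inputs. The `res = 1 % m` initialization is an extra tweak of mine (not part of the fix above) so that `m == 1` returns 0, as `pow` does.

```python
import random

def modular_square(a, n, m):
    """Return a**n % m by right-to-left square-and-multiply (n >= 0, m >= 1)."""
    res = 1 % m      # extra tweak: gives 0 when m == 1, matching pow()
    exp = n
    b = a % m        # reduce the base once up front
    while exp != 0:
        if exp % 2 == 1:
            res = (res * b) % m
        b = (b * b) % m  # keep b below m so it never grows
        exp >>= 1
    return res

# Sanity check against the built-in three-argument pow on random inputs
for _ in range(1000):
    a = random.randint(0, 10**6)
    n = random.randint(0, 10**6)
    m = random.randint(1, 10**6)
    assert modular_square(a, n, m) == pow(a, n, m)

print(modular_square(2, 10, 1000))  # 24, since 2**10 = 1024
```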
You can use this pseudocode as a reference.