Tags: python, performance, optimization, primes

Optimizing The Exact Prime Number Theorem


For example, given this sequence of the first 499 primes, can you predict the next prime?

2, 3, 5, 7, ..., 3541, 3547, 3557, 3559

The 500th prime is 3571.


Prime Number Theorem

The Prime Number Theorem (PNT) provides an approximation for the n-th prime:

$$p_n \approx n \ln n$$

Computing p_500 ≈ 3107 takes microseconds!
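Assuming the approximation shown is the classic p_n ≈ n ln n (it reproduces the 3107 figure), a one-line function is enough to evaluate it; `pnt_estimate` is just an illustrative name:

```python
import math


def pnt_estimate(n: int) -> float:
    """Classic Prime Number Theorem estimate of the n-th prime: p_n ~ n ln n."""
    return n * math.log(n)


print(round(pnt_estimate(500)))  # 3107, versus the true p_500 = 3571
```

The estimate is about 13% below the true value here, which is why it cannot pin down an exact prime.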


Exact Prime Number Theorem

My experimental Exact Prime Number Theorem (EPNT) computes the exact n-th prime:

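$$p_k = \left\lceil \left(1 - \frac{1}{\left(\sum_{n=1}^{N-1} n^{-s}\right) \prod_{i=1}^{k-1}\left(1 - p_i^{-s}\right)}\right)^{-1/s} \right\rceil, \qquad s = k \ln(k \ln k), \qquad N = \left\lceil k\,(\ln k + \ln\ln k) \right\rceil$$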

Computing p_500 = 3571 takes 25 minutes!
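The formula can be tried at small scale with only the standard library. This is a sketch, not the question's code: it swaps mpmath for Python's `decimal` module, multiplies the (1 - p^-s) factors directly instead of exponentiating a sum of logarithms, and takes the working precision as an explicit (hypothetical) `digits` parameter. As in the question's code, it evaluates p_k = ceil((1 - 1/(S * P))^(-1/s)) with s = k ln(k ln k), S the zeta sum over n = 1 .. N-1 and P the product over the first k-1 primes.

```python
import math
from decimal import Decimal, localcontext


def epnt_prime(k: int, known_primes: list[int], digits: int) -> int:
    """Evaluate the EPNT expression for the k-th prime at `digits` digits.

    known_primes must contain at least the first k-1 primes.
    Intended only for small k in this sketch.
    """
    with localcontext() as ctx:
        ctx.prec = digits
        kd = Decimal(k)
        s = kd * (kd * kd.ln()).ln()  # s = k ln(k ln k)
        # Same cut-off as the question: N = ceil(k (ln k + ln ln k))
        n_max = math.ceil(kd * (kd.ln() + kd.ln().ln()))
        # Truncated zeta sum over n = 1 .. N-1 (the question uses range(1, N))
        zeta_part = sum(Decimal(n) ** -s for n in range(1, n_max))
        # Product over the first k-1 primes of (1 - p^-s)
        prod = Decimal(1)
        for p in known_primes[: k - 1]:
            prod *= 1 - Decimal(p) ** -s
        x = 1 - 1 / (zeta_part * prod)
        return math.ceil(x ** (-1 / s))
```

For example, with `first_primes` holding the first 20 primes, `epnt_prime(10, first_primes, 100)` should return 29 and `epnt_prime(20, first_primes, 250)` should return 71. Even at this size the precision has to grow with k, which is what makes p_500 so slow.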


Question

So far, the EPNT correctly predicts the first 500 primes.

Unfortunately, numerically verifying the formula for higher primes is extremely slow!

Are there any optimization tips to improve the EPNT computational speed? Perhaps:

  • Do not use Python
  • Add multiple threads
  • Implement a faster math precision library
  • Modify the decimal precision mp.dps at runtime
  • Use a math computing engine like WolframAlpha

Here's the current Python code:

import time
from mpmath import ceil, ln, mp, mpf, exp, fsum, power, zeta
from sympy import symbols, Eq, pprint, prime

N=500   # <--- Compute the N-th prime.
    
mp.dps = 20000
primes = []

def vengy_prime(k):
    # Compute the k-th prime deterministically
    s = k * ln(k * ln(k))
    
    # Determine the dynamic Rosser (1941) upper bound
    N = int(ceil(k * (ln(k) + ln(ln(k)))))
    
    # Compute finite summation to N
    print(f"Computing {N} zeta terms ...")   
    start_time = time.time()
    sum_N = fsum([1 / power(mpf(n), s) for n in range(1, N)])
    end_time = time.time()      
    print(f"Time taken: {end_time - start_time:.6f} seconds")
    
    # Compute the product term involving the first k-1 primes
    print(f"Computing product of {k-1} previous primes ...")  
    start_time = time.time()
    prod = exp(fsum([ln(1 - power(p, -s)) for p in primes[:k-1]]))
    end_time = time.time()      
    print(f"Time taken: {end_time - start_time:.6f} seconds")
    
    # Compute next prime p_k
    p_k=ceil((1 - 1 / (sum_N * prod)) ** (-1 / s))
    return p_k


# Generate the previous known k-1 primes
print("\nListing", N-1, "known primes:")
for k in range(1, N):
    p = prime(k)
    primes.append(p)        
print(primes)        

primes.append(vengy_prime(N))
pprint(Eq(symbols(f'p_{N}'), int(primes[-1])))

Update

Wow! Running Jérôme Richard's new optimized code only took 10 seconds!

Computing 4021 zeta terms ...
Time taken: 7.968423 seconds
Computing product of 499 previous primes ...
Time taken: 1.960771 seconds
p₅₀₀ = 3571

The old code took about 1488 seconds (roughly 25 minutes):

Computing 4021 zeta terms ...
Time taken: 1173.899538 seconds
Computing product of 499 previous primes ...
Time taken: 313.833039 seconds
p₅₀₀ = 3571

The optimized code computed the 4000th prime in 45 minutes: N = 4000, precision = 700000

[Image: computed value of p_4000]


Solution

  • TL;DR: the code can be optimized with gmpy2 and accelerated with multiple processes, but the main issue is that this formula is simply a very inefficient way of finding the next prime number.

    Implement a faster math precision library

    mpmath is indeed a bit slow. You can just use gmpy2 instead, which is faster: gmpy2 is one of the fastest libraries I am aware of (for large numbers). Note that the very last digits computed by the two modules can differ (due to rounding and the accuracy of the math functions).
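    One detail when switching libraries: mpmath's `mp.dps` counts decimal digits, whereas gmpy2's context `precision` counts bits. A minimal conversion sketch (the helper names are illustrative only):

```python
import math


def digits_to_bits(digits: int) -> int:
    """Bits of binary precision needed to carry `digits` decimal digits."""
    # One decimal digit carries log2(10) ~ 3.3219 bits.
    return math.ceil(digits * math.log2(10))


def bits_to_digits(bits: int) -> int:
    """Whole decimal digits that fit in `bits` bits of binary precision."""
    return math.floor(bits * math.log10(2))


print(digits_to_bits(20000))  # 66439
print(bits_to_digits(66432))  # 19998
```

    The result can be assigned with `gmp.get_context().precision = digits_to_bits(20000)`; 66432 bits corresponds to just under 20000 decimal digits.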

    Do not use Python

    Native languages will not make this code significantly faster. Indeed, with gmpy2, most of the time should clearly be spent in the GMP/MPFR libraries, which are written in C and highly optimized. Thus, Python is fine here.

    Add multiple threads

    Indeed, since the operation is compute-bound, we can simply spawn several CPython processes (rather than threads). This can easily be done with joblib. However, there is a catch: we need to reset the mpmath/gmpy2 context (i.e. the precision) in each process so that the numbers are computed correctly.
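    The context issue can be avoided by making each task carry its own precision. The sketch below uses the standard `decimal` module as a stand-in for mpmath/gmpy2 (an assumption for illustration; `zeta_term` is a hypothetical name). A freshly started worker, such as those of joblib's default process-based loky backend, begins with the library's default precision (28 significant digits for `decimal`), so the precision is passed in and applied locally on every call:

```python
from decimal import Decimal, localcontext


def zeta_term(n: int, s: Decimal, digits: int) -> Decimal:
    """Return n**(-s) rounded to `digits` significant digits.

    The precision travels with the call and is applied locally, so the
    result does not depend on the default context of the executing process.
    """
    with localcontext() as ctx:
        ctx.prec = digits
        return Decimal(n) ** -s


if __name__ == "__main__":
    s, digits = Decimal(100), 60
    # Serial loop; every call is independent, so it can be handed to a pool,
    # e.g. joblib's Parallel(n_jobs=-1)(delayed(zeta_term)(n, s, digits) ...).
    terms = [zeta_term(n, s, digits) for n in range(1, 10)]
    with localcontext() as ctx:
        ctx.prec = digits
        print(sum(terms))  # keeps the tiny terms
    print(sum(terms))      # default 28-digit context: the tiny terms vanish
```

    The same pattern is what the `gmp.get_context().precision = precision` line inside each `compute` function achieves in the code below.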


    Here is the code applying all optimizations so far:

    import time
    import gmpy2 as gmp
    from gmpy2 import log as ln, ceil, mpfr, exp
    from sympy import symbols, Eq, pprint, prime
    from joblib import Parallel, delayed
    
    N = 500   # <--- Compute the N-th prime.
    precision = 66432 # 20_000 decimals ~= 66432 bits
    gmp.get_context().precision = precision
    primes = []
    
    # TODO: use a pair-wise sum for better precision or even a Kahan sum!
    fsum = sum
    
    def vengy_prime(k):
        # Compute the k-th prime deterministically
        s = k * ln(k * ln(k))
        
        # Determine the dynamic Rosser (1941) upper bound
        N = int(ceil(k * (ln(k) + ln(ln(k)))))
    
        ms = -s
        parallel = Parallel(-1)
        
        # Compute finite summation to N
        print(f"Computing {N} zeta terms ...")   
        start_time = time.time()
        def compute(n):
            gmp.get_context().precision = precision
            return n ** ms
        lst = parallel(delayed(compute)(n) for n in range(1, N))
        sum_N = fsum(lst)
        #sum_N = fsum([n**ms for n in range(1, N)])
        end_time = time.time()      
        print(f"Time taken: {end_time - start_time:.6f} seconds")
        
        # Compute the product term involving the first k-1 primes
        print(f"Computing product of {k-1} previous primes ...")  
        start_time = time.time()
        def compute(p):
            gmp.get_context().precision = precision
            return ln(1-p**ms)
        lst = parallel(delayed(compute)(p) for p in primes[:k-1])
        prod = exp(fsum(lst))
        #prod = exp(fsum([ln(1 - p**ms) for p in primes[:k-1]]))
        end_time = time.time()      
        print(f"Time taken: {end_time - start_time:.6f} seconds")
        
        # Compute next prime p_k
        p_k=ceil((1 - 1 / (sum_N * prod)) ** (-1 / s))
        return p_k
    
    # Generate the previous known k-1 primes
    print("\nListing", N-1, "known primes:")
    for k in range(1, N):
        p = prime(k)
        primes.append(p)        
    print(primes)        
    
    primes.append(vengy_prime(N))
    pprint(Eq(symbols(f'p_{N}'), int(primes[-1])))
    

    On my i5-9600KF CPU (6 cores), the optimized code takes 18.3 seconds compared to 128.1 seconds for the original one. This means the optimized code is 7 times faster.

    Modify the decimal precision mp.dps at runtime

    This is a good idea. In fact, you do not need the value to be exact! If the last digit is wrong, this is not a problem, since you can test whether a number is prime in polynomial time, and relatively quickly even for pretty big numbers. For example, you can use the Miller–Rabin primality test. Note that there are also deterministic algorithms that do not rely on unproven hypotheses, such as AKS (AFAIK significantly slower in practice and more complex to implement). There is a catch, though: you need to quantify the error for big numbers so as to know the exact range of numbers to search. This can be done by running the same algorithm several times with different rounding modes (and possibly different algorithms) so that the minimum and maximum values are actual lower and upper bounds.
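    A minimal Miller–Rabin sketch for checking candidates in such a range. It uses the first twelve primes as witnesses, a set known to make the test deterministic for n < 2**64; above that it is a (very strong) probable-prime test. The function name is illustrative:

```python
def is_probable_prime(n: int) -> bool:
    """Miller-Rabin test; deterministic for n < 2**64 with these bases."""
    if n < 2:
        return False
    bases = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)
    for p in bases:  # quick trial division by the witness primes
        if n % p == 0:
            return n == p
    # Write n - 1 = d * 2**r with d odd
    d, r = n - 1, 0
    while d % 2 == 0:
        d //= 2
        r += 1
    for a in bases:
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False  # a is a witness: n is composite
    return True


# Scan a window around an approximate result
print([n for n in range(3560, 3580) if is_probable_prime(n)])  # [3571]
```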

    In the end, I think it is faster to run a primality test on the numbers following primes[i] to find primes[i+1] (much more efficiently than with your method). For relatively small numbers like the ones you test (or even numbers below 10_000_000), a basic sieve of Eratosthenes is much faster.
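    Both alternatives fit in a few lines. The sketch below (illustrative function names) sieves all primes up to a limit and, separately, finds the prime following a complete list of known primes by trial division with those primes:

```python
import math


def sieve(limit: int) -> list[int]:
    """All primes <= limit, via the sieve of Eratosthenes."""
    if limit < 2:
        return []
    is_prime = bytearray([1]) * (limit + 1)
    is_prime[0] = is_prime[1] = 0
    for i in range(2, math.isqrt(limit) + 1):
        if is_prime[i]:
            # Cross out the multiples of i, starting at i*i
            is_prime[i * i :: i] = bytes(len(range(i * i, limit + 1, i)))
    return [i for i, flag in enumerate(is_prime) if flag]


def next_prime(known: list[int]) -> int:
    """Prime following known[-1]; `known` must list every prime up to known[-1]."""
    candidate = known[-1] + 1
    while True:
        root = math.isqrt(candidate)
        for p in known:
            if p > root:
                return candidate  # no divisor up to sqrt(candidate)
            if candidate % p == 0:
                break  # composite: try the next candidate
        else:
            return candidate
        candidate += 1


primes = sieve(3559)
print(len(primes), next_prime(primes))  # 499 3571
```

    Trial division only needs primes up to the square root of the candidate, and by Bertrand's postulate the next prime is below 2 * known[-1], so the known primes always suffice.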


    Root of the problem

    The heart of the problem with your formula is the following:

    First, it requires numbers to be very precise for this to work, and this precision must increase as the searched number increases (IDK exactly how much).
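    A rough estimate (a heuristic, not a proven bound) follows from the formula itself: 1 - 1/(sum * prod) behaves like p_k^(-s), so the working precision must at least resolve a quantity of about 10^(-s * log10(p_k)). Note that s = k ln(k ln k) = k (ln k + ln ln k) is the same quantity as the Rosser bound N used for the cut-off, so N = ceil(s) can stand in as an upper bound for p_k. The required digits therefore grow roughly like k (ln k)². The helper name below is illustrative:

```python
import math


def epnt_precision_estimate(k: int) -> tuple[int, int]:
    """Rough (digits, bits) needed to resolve the p_k**(-s) term.

    Uses the Rosser bound N = ceil(k (ln k + ln ln k)) in place of p_k.
    """
    s = k * math.log(k * math.log(k))  # same s as in the question
    n_bound = math.ceil(k * (math.log(k) + math.log(math.log(k))))
    digits = math.ceil(s * math.log10(n_bound))
    bits = math.ceil(digits * math.log2(10))
    return digits, bits


print(epnt_precision_estimate(500))   # about 14,500 digits
print(epnt_precision_estimate(4000))  # about 639,000 bits
```

    This is consistent with the 20000 decimal digits used for k = 500 and with precision = 700000 (bits, in the gmpy2 version) for k = 4000.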

    Moreover, for a required precision p, the (best known) complexity of operations like multiplication/division is Θ(p log p), and that of exponentiation (in finite fields) is roughly Θ(p²). See this article for more information about the complexity of mathematical operations.

    Last but not least, it requires Θ(N) iterations (where N is the number of primes), so it should take about Θ(N p²) operations. This is not great, especially if p ~ Ω(N) (rather likely, since you use a >66400-bit precision while N is only 500), which would result in Ω(N³) complexity. With p ~ Ω(N²), it would even be Ω(N⁵), which is pretty bad. By comparison, a sieve of Eratosthenes finds the first N primes in roughly Θ(N log N log log N) operations †.

    In practice, you already have all the previous prime numbers. Thus, finding the next one is cheaper than running a full sieve of Eratosthenes. Indeed, the average gap between two consecutive primes around the N-th prime is about ln(N ln N) ≈ ln N, and testing whether a number is prime can be done in O(√(N log N) / log N) divisions with the trial division algorithm (much less with the Miller–Rabin test). This means the complexity of using trial division on the next numbers in order to find the next prime is O(√(N log N)). This is much better than your algorithm, which has a lower bound of Ω(N).

    In short, this formula is a very inefficient way of finding the next prime number.


    † Sieving all numbers up to M takes Θ(M log log M) operations, and the N-th prime is about N ln N, so the sieve must cover M ≈ N ln N. This assumes that all numbers are small enough to fit in native 64-bit integers (so each operation takes constant time). For much bigger numbers, the sieve of Eratosthenes would take far too much memory to run (as would your algorithm, which requires all the previous prime numbers to be stored).