Tags: python, algorithm, optimization, mathematical-optimization, primes

Python nested lists search optimization


I have a search-and-test problem: in the list of prime numbers from 2 to 100k, we're searching for the first set of 5 primes that satisfies the following criteria:

  • p1 < p2 < p3 < p4 < p5
  • concatenating any 2 primes from the solution, in either order, must also give a prime (e.g. 3 and 7 => 37 and 73)
  • sum(p1..p5) is the smallest possible sum of primes satisfying the criteria, and is above 100k
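To make the concatenation criterion concrete, here is a minimal pure-Python sketch of the pairwise check. The helper names `is_prime`, `concat` and `is_valid_set` are hypothetical, not from the question. `concat` joins the digits arithmetically instead of going through strings. The example set {3, 7, 109, 673} is the four-prime set with this property quoted in Project Euler problem 60.

```python
from itertools import combinations


def is_prime(n: int) -> bool:
    """Trial division by odd numbers; fine for the small values used here."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True


def concat(a: int, b: int) -> int:
    """Join the decimal digits of a and b, e.g. concat(3, 7) == 37."""
    shift = 10
    while shift <= b:  # find the smallest power of 10 greater than b
        shift *= 10
    return a * shift + b


def is_valid_set(primes) -> bool:
    """True if every pair concatenates to a prime in both orders."""
    return all(
        is_prime(concat(a, b)) and is_prime(concat(b, a))
        for a, b in combinations(primes, 2)
    )


print(concat(3, 7), concat(7, 3))      # 37 73
print(is_valid_set([3, 7, 109, 673]))  # True
print(is_valid_set([2, 3]))            # False, since 32 is even
```

The `[2, 3]` case also shows why 2 can never be part of a solution: any concatenation ending in 2 is even.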

I can code such a thing, but I have a severe optimization problem: my code is extremely slow. I have a list of primes under 100k, a list of primes over 100k, and a primality test that works well, but I do not see how to optimize it to get a result in a reasonable time.
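If the primality test is trial division, as in the answer's `prime` function, each check on a 10-digit concatenation can take up to about 100k divisions. A commonly used alternative is deterministic Miller-Rabin: with the first six primes (2 to 13) as bases, it is known to be exact for n < 3,474,749,660,383, which covers every concatenation of two primes below 100k (at most 10 digits). The sketch below assumes that bound. The name `is_prime_mr` is hypothetical.

```python
def is_prime_mr(n: int) -> bool:
    """Deterministic Miller-Rabin, exact for n < 3,474,749,660,383."""
    if n < 2:
        return False
    bases = (2, 3, 5, 7, 11, 13)
    # Handle the bases themselves and their multiples directly.
    for p in bases:
        if n % p == 0:
            return n == p
    # Write n - 1 as d * 2**s with d odd.
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for a in bases:
        x = pow(a, d, n)  # modular exponentiation, built into Python
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = x * x % n
            if x == n - 1:
                break
        else:
            return False  # a is a witness: n is composite
    return True


print(is_prime_mr(2147483647))  # True (2**31 - 1)
print(is_prime_mr(3215031751))  # False (151 * 751 * 28351)
```

Each call costs a handful of modular exponentiations instead of a long division loop, so the pair tests should become much cheaper. Inside Numba the built-in three-argument `pow` may need replacing with a hand-written square-and-multiply loop.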

For a sense of scale:

  • the list of all primes under 100k contains 9592 items
  • the list of all primes under 1 billion contains approximately 51 million lines (50,847,534 primes)
  • I have the list of all primes under 1 billion, by length
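The 9592 figure can be reproduced with a sieve of Eratosthenes, which generates all primes below a limit much faster than testing each candidate by trial division (as the answer's `generate_primes` does). This is a minimal sketch; `primes_below` is a hypothetical helper name.

```python
def primes_below(limit: int) -> list[int]:
    """Sieve of Eratosthenes: all primes p with p < limit."""
    if limit < 3:
        return []
    is_p = bytearray([1]) * limit  # is_p[i] == 1 means "i may be prime"
    is_p[0] = is_p[1] = 0
    for i in range(2, int(limit ** 0.5) + 1):
        if is_p[i]:
            # Cross out every multiple of i starting at i*i.
            is_p[i * i :: i] = bytearray(len(range(i * i, limit, i)))
    return [i for i, flag in enumerate(is_p) if flag]


primes = primes_below(100_000)
print(len(primes))             # 9592
print(primes[:5], primes[-1])  # [2, 3, 5, 7, 11] 99991
```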

Thanks for the help


Solution

  • Here is a version that computes the minimal combination of primes satisfying the restrictions stated in the question.

    Most of the runtime is spent pre-computing all valid combinations of prime numbers. On my computer (AMD 5700X) this runs in 1 minute 20 seconds:

    import numpy as np
    from numba import njit, prange
    
    
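    @njit
    def prime(a):
        # Plain trial division. Concatenations of two 5-digit primes have up
        # to 10 digits, so this loops up to ~100k times for a prime input.
        if a < 2:
            return False
        for x in range(2, int(a**0.5) + 1):
            if a % x == 0:
                return False
        return True
    
    
    @njit
    def str_to_int(s):
        # Parse a string of decimal digits into an int (ord("0") == 48).
        final_index, result = len(s) - 1, 0
        for i, v in enumerate(s):
            result += (ord(v) - 48) * (10 ** (final_index - i))
        return result
    
    
    @njit
    def generate_primes(n):
        # Start at 3: 2 can never be part of a solution, because any
        # concatenation ending in the digit 2 is even.
        out = []
        for i in range(3, n + 1):
            if prime(i):
                out.append(i)
        return out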
    
    
    @njit(parallel=True)
    def get_comb(n=100_000):
        # generate all primes < n
        primes = generate_primes(n)
        n_primes = len(primes)
    
        # generate all valid combinations of primes
        combs = np.zeros((n_primes, n_primes), dtype=np.uint8)
    
        for i in prange(n_primes):
            for j in range(i + 1, n_primes):  # a nested prange runs serially anyway
                p1, p2 = primes[i], primes[j]
    
                c1 = str_to_int(f"{p1}{p2}")
                c2 = str_to_int(f"{p2}{p1}")
    
                if not prime(c1) or not prime(c2):
                    continue
    
                combs[i, j] = 1
    
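        # Search for 5-cliques in the compatibility matrix. Only the outermost
        # prange runs in parallel (Numba serializes nested pranges), so the
        # inner loops use range. Appending to a shared list inside prange is a
        # data race, so each iteration writes only its own row of best/best_sum
        # (the smallest valid sum found for that p1).
        best = np.zeros((n_primes, 5), dtype=np.int64)
        best_sum = np.zeros(n_primes, dtype=np.int64)  # 0 = nothing found
    
        for i_p1 in prange(n_primes):
            for i_p2 in range(i_p1 + 1, n_primes):
                if combs[i_p1, i_p2] == 0:
                    continue
                for i_p3 in range(i_p2 + 1, n_primes):
                    if combs[i_p1, i_p3] == 0 or combs[i_p2, i_p3] == 0:
                        continue
                    for i_p4 in range(i_p3 + 1, n_primes):
                        if (
                            combs[i_p1, i_p4] == 0
                            or combs[i_p2, i_p4] == 0
                            or combs[i_p3, i_p4] == 0
                        ):
                            continue
                        for i_p5 in range(i_p4 + 1, n_primes):
                            if (
                                combs[i_p1, i_p5] == 0
                                or combs[i_p2, i_p5] == 0
                                or combs[i_p3, i_p5] == 0
                                or combs[i_p4, i_p5] == 0
                            ):
                                continue
    
                            ccomb = np.array(
                                [
                                    primes[i_p1],
                                    primes[i_p2],
                                    primes[i_p3],
                                    primes[i_p4],
                                    primes[i_p5],
                                ],
                                dtype=np.int64,
                            )
                            s = np.sum(ccomb)
                            if s < n:
                                continue
    
                            print(ccomb)
                            if best_sum[i_p1] == 0 or s < best_sum[i_p1]:
                                best_sum[i_p1] = s
                                best[i_p1, :] = ccomb
                            # primes are sorted, so the first valid p5 gives
                            # the smallest sum for this (p1, p2, p3, p4)
                            break
    
        # Collect the per-p1 winners serially, outside the parallel loop.
        all_combs = []
        for i in range(n_primes):
            if best_sum[i] > 0:
                all_combs.append(best[i].copy())
    
        return all_combs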
    
    
    all_combs = np.array(get_comb())
    print()
    print("Minimal combination:")
    print(all_combs[np.sum(all_combs, axis=1).argmin()])
    

    Prints:

    [    3 28277 44111 70241 78509]
    [    7    61 25939 26893 63601]
    [    7    61 25939 61417 63601]
    [    7    61 25939 61471 86959]
    [    7  2467 24847 55213 92593]
    [    7  3361 30757 49069 57331]
    
    ...
    
    [ 1993 12823 35911 69691 87697]
    [ 2287  4483  6793 27823 67723]
    [ 3541  9187 38167 44257 65677]
    
    Minimal combination:
    [   13   829  9091 17929 72739]
    
    real    1m20,599s
    user    0m0,011s
    sys     0m0,008s
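A different technique from the five nested loops above is a branch-and-bound clique search over the compatibility graph: primes are vertices, an edge means both concatenations are prime, and a partial set is abandoned as soon as its sum can no longer beat the best set found so far. The following is a minimal pure-Python sketch of that idea, not the answer's method; `min_sum_clique`, `is_prime` and `concat` are hypothetical names. It is checked on the small case with 4 primes below 1000, where the lowest sum is 792 (the Project Euler 60 set {3, 7, 109, 673}).

```python
from itertools import combinations


def is_prime(n: int) -> bool:
    """Trial division by odd numbers; fine for this small demo."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True


def concat(a: int, b: int) -> int:
    """Join the decimal digits of a and b, e.g. concat(3, 7) == 37."""
    shift = 10
    while shift <= b:
        shift *= 10
    return a * shift + b


def min_sum_clique(primes, k, min_total=0):
    """Smallest-sum set of k primes (sum >= min_total) in which every pair
    concatenates to a prime in both orders. Returns (sum, members) or None."""
    primes = sorted(primes)
    # Compatibility graph: adj[a] holds the larger primes b that pair with a.
    adj = {p: set() for p in primes}
    for a, b in combinations(primes, 2):  # a < b because primes is sorted
        if is_prime(concat(a, b)) and is_prime(concat(b, a)):
            adj[a].add(b)

    best = None

    def extend(clique, candidates, total):
        nonlocal best
        if len(clique) == k:
            if total >= min_total and (best is None or total < best[0]):
                best = (total, clique)
            return
        for i, p in enumerate(candidates):
            # Candidates are ascending: once adding p reaches the best total,
            # every later (larger) p would too, so stop this branch.
            if best is not None and total + p >= best[0]:
                break
            # Keep only larger candidates that are also compatible with p.
            nxt = [q for q in candidates[i + 1:] if q in adj[p]]
            extend(clique + [p], nxt, total + p)

    extend([], primes, 0)
    return best


total, members = min_sum_clique([p for p in range(2, 1000) if is_prime(p)], 4)
print(total)  # 792
print(members)
```

Mirroring the question would mean calling it with the primes below 100k, `k=5` and `min_total=100_000`, but building the pair graph for ~9,600 primes takes about 46 million pair checks, which is far too slow in pure Python; that is the step the answer above moves into Numba.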