Tags: python, python-3.x, numpy, primes, sieve-of-eratosthenes

How to optimize NumPy Sieve of Eratosthenes?


I have made my own Sieve of Eratosthenes implementation in NumPy. You probably all know that it finds all primes up to a given number, so I won't explain it further.

Code:

import numpy as np

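def primes_sieve(n):
    primes = np.ones(n+1, dtype=bool)       # primes[k] stays True while k is still a candidate
    primes[:2] = False                      # 0 and 1 are not prime
    primes[4::2] = False                    # even numbers above 2 are composite
    for i in range(3, int(n**0.5)+1, 2):    # odd candidates up to sqrt(n)
        if primes[i]:
            primes[i*i::i] = False          # cross off multiples of i, starting at i*i

    return np.where(primes)[0]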

As you can see, I have already made some optimizations. First, every prime except 2 is odd, so I set all even numbers above 2 to False up front and only loop over odd candidates.

Second, I only loop up to the floor of the square root of n, because every composite number up to n has a prime factor no larger than its square root, so it has already been crossed off as a multiple of a prime below the square root.
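The square-root bound is easy to sanity-check by brute force for a small n. The sketch below (smallest_factor, check_sqrt_bound and check_start_at_square are hypothetical helper names, not part of the question's code) checks two facts the sieve relies on: every composite m ≤ n has a prime factor no larger than isqrt(n), and for a prime p every multiple k*p with k < p has a smaller prime factor, which is why the inner slice can start at i*i. It uses math.isqrt (Python 3.8+) instead of int(n**0.5), which avoids floating-point rounding.

```python
import math

def smallest_factor(m: int) -> int:
    """Hypothetical helper: smallest prime factor of m >= 2, by trial division."""
    for d in range(2, math.isqrt(m) + 1):
        if m % d == 0:
            return d
    return m  # no divisor up to sqrt(m), so m is prime

def check_sqrt_bound(n: int) -> bool:
    """True if every composite m <= n has a prime factor <= isqrt(n)."""
    limit = math.isqrt(n)
    for m in range(4, n + 1):
        p = smallest_factor(m)
        if p != m and p > limit:  # a composite the sieve loop would miss
            return False
    return True

def check_start_at_square(n: int) -> bool:
    """True if, for each prime p <= isqrt(n), the multiples k*p with k < p
    all have a prime factor smaller than p (so they are crossed off earlier)."""
    for p in range(2, math.isqrt(n) + 1):
        if smallest_factor(p) != p:
            continue  # only primes start a crossing-off pass
        for k in range(2, p):
            if smallest_factor(k * p) >= p:
                return False
    return True

print(check_sqrt_bound(10_000), check_start_at_square(10_000))  # True True
```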

But this still isn't optimal, because the loop visits every odd number up to the square root, and not all of them are prime. Primes also become sparser as numbers grow, so there are many wasted iterations.
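To put a number on the wasted iterations for n = 65536 (the value used in the timings), one can count how many odd i the loop visits and how many of those are actually prime, i.e. actually trigger a slice assignment. The block below repeats the question's primes_sieve so it runs on its own.

```python
import math
import numpy as np

def primes_sieve(n):
    # Same odd-only sieve as in the question.
    primes = np.ones(n + 1, dtype=bool)
    primes[:2] = False
    primes[4::2] = False
    for i in range(3, int(n**0.5) + 1, 2):
        if primes[i]:
            primes[i*i::i] = False
    return np.where(primes)[0]

n = 65536
limit = math.isqrt(n)                                   # 256
odd_candidates = len(range(3, limit + 1, 2))            # odd i visited by the loop
odd_primes = int(np.count_nonzero(primes_sieve(limit) > 2))  # odd i that are prime
print(odd_candidates, odd_primes)  # 127 53
```

So 74 of the 127 iterations only evaluate the if primes[i] test and move on; each of those is a single scalar lookup.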

If the set of candidates were updated dynamically, so that composites already crossed off are never visited and only primes are looped over, there would be no wasted iterations and the algorithm would be optimal.

I have written a crude implementation of the optimized version:

def primes_sieve_opt(n):
    primes = np.ones(n+1, dtype=bool)
    primes[:2] = False
    primes[4::2] = False
    limit = int(n**0.5)+1
    i = 2
    while i < limit:
        primes[i*i::i] = False              # cross off multiples of the current prime i
        i += 1 + primes[i+1:].argmax()      # jump to the next index still marked True

    return np.where(primes)[0]

But it is actually slower than the unoptimized version:

In [92]: %timeit primes_sieve(65536)
271 µs ± 22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [102]: %timeit primes_sieve_opt(65536)
309 µs ± 3.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

My idea is simple: by jumping to the index of the next True, I can ensure that all primes are covered and only primes are processed.

However, np.argmax is slow for this. I searched Google for "how to find the index of the next True value in NumPy array" (without the quotes) and found several Stack Overflow questions that are related but ultimately don't answer my question.

For example, "numpy get index where value is true" and "Numpy first occurrence of value greater than existing value".

I am not trying to find all indices where the value is True; doing that would be wasteful. I need to find the next True value, get its index, and stop scanning immediately. The array contains only bools.
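For reference, the "index of the next True" lookup can be written with np.argmax on a boolean slice, as primes_sieve_opt does. NumPy documents that argmax returns the first occurrence of the maximum, which for a bool array is the first True; but it also returns 0 when the slice is all False, and it raises on an empty slice, so a standalone helper needs explicit checks. A minimal sketch (next_true_index is a hypothetical name):

```python
import numpy as np

def next_true_index(a: np.ndarray, start: int) -> int:
    """Return the index of the first True in a[start:], or -1 if there is none."""
    view = a[start:]           # basic slicing returns a view, not a copy
    if view.size == 0:
        return -1              # argmax would raise ValueError on an empty array
    j = int(np.argmax(view))   # first True, or 0 if the slice is all False
    return start + j if view[j] else -1

flags = np.array([False, True, False, False, True, False])
print(next_true_index(flags, 0))  # 1
print(next_true_index(flags, 2))  # 4
print(next_true_index(flags, 5))  # -1
```

Inside primes_sieve_opt the all-False case is harmless, because i still advances by at least 1 per iteration and the loop is bounded by limit.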

How can I optimize this?


Edit

If anyone is interested, I have optimized my algorithm further:

import numba
import numpy as np

@numba.jit(nopython=True, parallel=True, fastmath=True, forceobj=False)
def prime_sieve(n: int) -> np.ndarray:
    primes = np.full(n + 1, True)
    primes[:2] = False
    primes[4::2] = False
    primes[9::6] = False
    limit = int(n**0.5) + 1
    for i in range(5, limit, 6):
        if primes[i]:
            primes[i * i :: 2 * i] = False

    for i in range(7, limit, 6):
        if primes[i]:
            primes[i * i :: 2 * i] = False

    return np.flatnonzero(primes)

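I used numba to speed things up. Also, since every prime other than 2 and 3 has the form 6k+1 or 6k-1, the loops only need to visit those candidates, and stepping by 2*i skips the even multiples, which makes things even faster.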
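Two quick checks of the 6k ± 1 version that run without numba: every prime above 3 is congruent to 1 or 5 mod 6, and a plain-NumPy port of the loop body returns the same primes as the original primes_sieve. The name prime_sieve_wheel is hypothetical; it is the numba function's body with the decorator removed, so it shows correctness, not the numba speedup.

```python
import numpy as np

def primes_sieve(n):
    # Reference: the odd-only sieve from the question.
    primes = np.ones(n + 1, dtype=bool)
    primes[:2] = False
    primes[4::2] = False
    for i in range(3, int(n**0.5) + 1, 2):
        if primes[i]:
            primes[i*i::i] = False
    return np.where(primes)[0]

def prime_sieve_wheel(n):
    # Body of the numba version above, without @numba.jit.
    primes = np.full(n + 1, True)
    primes[:2] = False
    primes[4::2] = False               # even numbers >= 4
    primes[9::6] = False               # odd multiples of 3, starting at 9
    limit = int(n**0.5) + 1
    for i in range(5, limit, 6):       # candidates of the form 6k - 1
        if primes[i]:
            primes[i*i::2*i] = False   # odd multiples of i, starting at i*i
    for i in range(7, limit, 6):       # candidates of the form 6k + 1
        if primes[i]:
            primes[i*i::2*i] = False
    return np.flatnonzero(primes)

p = primes_sieve(100_000)
assert np.all(np.isin(p[p > 3] % 6, [1, 5]))          # every prime > 3 is 6k +/- 1
assert np.array_equal(p, prime_sieve_wheel(100_000))  # same primes as the original
```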


Solution

  • My idea is simple, by getting the next index of True, I can ensure all primes are covered and only primes are processed.

    Some profiling suggests you could get at best a 0.2% speedup this way.

    For large values of n, the vast majority of the time is spent in primes[i*i::i] = False.

    Here is the line_profiler output for n = 100,000,000 (all primes below one hundred million):

    Timer unit: 1e-09 s
    
    Total time: 1.04878 s
    File: /tmp/ipykernel_22262/2557137730.py
    Function: primes_sieve at line 3
    
    Line #      Hits         Time  Per Hit   % Time  Line Contents
    ==============================================================
         3                                           def primes_sieve(n):
         4         1   14264754.0 14264754.0      1.4      primes = np.ones(n+1, dtype=bool)
         5         1      12394.0  12394.0      0.0      primes[:2] = False
         6         1   16238905.0 16238905.0      1.5      primes[4::2] = False
         7      4999    1309955.0    262.0      0.1      for i in range(3, int(n**0.5)+1, 2):
         8      3771    1507909.0    399.9      0.1          if primes[i]:
         9      1228  909007228.0 740233.9     86.7              primes[i*i::i] = False
        10                                           
        11         1  106434647.0 106434647.0     10.1      return np.where(primes)[0]
    

    If you skipped more values of i, you could avoid the time spent on the lines "for i in range(3, int(n**0.5)+1, 2):" and "if primes[i]:", but not the time spent in primes[i*i::i] = False. Since each of those two lines accounts for about 0.1% of the run time, you could save at most 0.2% of the total.
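A timing-independent way to see why the slice assignments dominate is to count how many array elements primes[i*i::i] = False writes in total and compare that with the number of loop iterations. The sketch below does this for the question's odd-only loop (count_work is a hypothetical helper; it counts element writes, not time, and uses math.isqrt in place of int(n**0.5), which gives the same bound for these n).

```python
import math
import numpy as np

def count_work(n):
    """Return (loop iterations, elements written by primes[i*i::i] = False)
    for the odd-only sieve in the question."""
    primes = np.ones(n + 1, dtype=bool)
    primes[:2] = False
    primes[4::2] = False
    iterations = 0
    writes = 0
    for i in range(3, math.isqrt(n) + 1, 2):
        iterations += 1
        if primes[i]:
            writes += len(range(i * i, n + 1, i))  # length of the slice i*i::i
            primes[i*i::i] = False
    return iterations, writes

its, writes = count_work(1_000_000)
print(its, writes)
```

The multiples of 3 alone account for over 300,000 writes at n = 1,000,000, against 499 loop iterations, which is consistent with the profile above: skipping iterations cannot remove the element writes.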