I have written my own Sieve of Eratosthenes implementation in NumPy. As you all know, the sieve finds all primes up to a given number, so I won't explain it further.
Code:
import numpy as np

def primes_sieve(n):
    primes = np.ones(n+1, dtype=bool)
    primes[:2] = False
    primes[4::2] = False
    for i in range(3, int(n**0.5)+1, 2):
        if primes[i]:
            primes[i*i::i] = False
    return np.where(primes)[0]
As you can see, I have already made some optimizations. First, all primes except 2 are odd, so I set all even numbers greater than 2 to False and only loop over the odd numbers.
Second, I only loop through numbers up to the floor of the square root of n, because every composite number up to n has a prime factor no larger than its square root and is therefore already eliminated as a multiple of that prime.
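The square-root argument can be checked empirically. The sketch below is only an illustration: the helper names sieve_sqrt and is_prime_trial are made up here, and math.isqrt is used in place of int(n**0.5) simply to keep the bound exact. It sieves with factors up to floor(sqrt(n)) only and compares the result against plain trial division.

```python
import math
import numpy as np

def sieve_sqrt(n):
    """Cross out multiples of i only for i <= floor(sqrt(n))."""
    flags = np.ones(n + 1, dtype=bool)
    flags[:2] = False
    for i in range(2, math.isqrt(n) + 1):
        if flags[i]:
            # Smaller multiples of i were already removed by smaller primes.
            flags[i * i :: i] = False
    return np.flatnonzero(flags)

def is_prime_trial(k):
    """Slow reference check by trial division, used only for verification."""
    return k >= 2 and all(k % d for d in range(2, math.isqrt(k) + 1))

n = 1000
expected = [k for k in range(n + 1) if is_prime_trial(k)]
assert sieve_sqrt(n).tolist() == expected
```

If the assertion passes, no composite up to n survived even though no factor above floor(sqrt(n)) was ever used.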
But this sieve still isn't optimal, because it loops through all odd numbers below the limit, and not all odd numbers are prime. As numbers grow larger, primes become sparser, so there are many redundant iterations.
If the list of candidates were updated dynamically, so that numbers already identified as composite are never iterated over and only primes are looped through, there would be no wasteful iterations and the algorithm would be optimal.
I have written a crude implementation of the optimized version:
def primes_sieve_opt(n):
    primes = np.ones(n+1, dtype=bool)
    primes[:2] = False
    primes[4::2] = False
    limit = int(n**0.5)+1
    i = 2
    while i < limit:
        primes[i*i::i] = False
        i += 1 + primes[i+1:].argmax()
    return np.where(primes)[0]
But it is slower than the unoptimized version (about 14% in this run):
In [92]: %timeit primes_sieve(65536)
271 µs ± 22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [102]: %timeit primes_sieve_opt(65536)
309 µs ± 3.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
My idea is simple: by getting the index of the next True, I can ensure that all primes are covered and only primes are processed.
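To make the "next True" step concrete, here is a small sketch of what the argmax expression in the loop computes. The helper name next_true and the sample array are assumptions for illustration only.

```python
import numpy as np

flags = np.array([False, False, True, True, False, True, False, False])

def next_true(flags, i):
    """Index of the first True strictly after position i (hypothetical helper).

    Caveat: argmax returns 0 when the slice holds no True at all, so the
    result is i + 1 in that case even though flags[i + 1] is False.
    """
    return i + 1 + int(flags[i + 1 :].argmax())

print(next_true(flags, 2))  # 3
print(next_true(flags, 3))  # 5
print(next_true(flags, 5))  # 6: no True left, argmax falls back to offset 0
```

On a boolean array, argmax returns the position of the first maximum, which is the first True when one exists. The last call shows the fallback case, which a caller would have to guard against separately.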
However, np.argmax is slow for this. I searched Google for "how to find the index of the next True value in NumPy array" (without quotes) and found several Stack Overflow questions that are slightly relevant but ultimately don't answer my question, for example "numpy get index where value is true" and "Numpy first occurrence of value greater than existing value".
I am not trying to find all indexes where the value is True; doing that would be extremely wasteful. I need to find the next True value, get its index, and stop looping immediately. The array contains only bools.
How can I optimize this?
If anyone is interested, I have optimized my algorithm further:
import numba
import numpy as np

@numba.jit(nopython=True, parallel=True, fastmath=True, forceobj=False)
def prime_sieve(n: int) -> np.ndarray:
    primes = np.full(n + 1, True)
    primes[:2] = False
    primes[4::2] = False
    primes[9::6] = False
    limit = int(n**0.5) + 1
    for i in range(5, limit, 6):
        if primes[i]:
            primes[i * i :: 2 * i] = False
    for i in range(7, limit, 6):
        if primes[i]:
            primes[i * i :: 2 * i] = False
    return np.flatnonzero(primes)
I used numba to speed things up. And since all primes except 2 and 3 are of the form 6k+1 or 6k-1, only those candidates need to be checked, which makes things even faster.
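The 6k+1 / 6k-1 claim follows from the fact that numbers with remainder 0, 2, or 4 modulo 6 are even and numbers with remainder 3 are divisible by 3. A minimal sketch that checks this numerically, using a plain reference sieve written just for this check (n = 10_000 and limit = 50 are arbitrary choices):

```python
import numpy as np

n = 10_000
# Plain reference sieve, used only to obtain a list of primes to inspect.
flags = np.ones(n + 1, dtype=bool)
flags[:2] = False
for i in range(2, int(n**0.5) + 1):
    if flags[i]:
        flags[i * i :: i] = False
primes = np.flatnonzero(flags)

# Every prime above 3 leaves remainder 1 or 5 when divided by 6.
residues = set((primes[primes > 3] % 6).tolist())
print(residues)  # {1, 5}

# range(5, limit, 6) yields numbers of the form 6k - 1 and
# range(7, limit, 6) yields numbers of the form 6k + 1.
limit = 50
candidates = sorted(list(range(5, limit, 6)) + list(range(7, limit, 6)))
print(candidates[:6])  # [5, 7, 11, 13, 17, 19]
```

The two ranges therefore skip every multiple of 2 and 3, which is why the numba version removes those multiples up front with primes[4::2] and primes[9::6].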
"My idea is simple, by getting the next index of True, I can ensure all primes are covered and only primes are processed."
Some profiling suggests you could get at best a 0.2% speedup this way.
For large values of N, the vast majority of the time is spent doing primes[i*i::i] = False.
Here's the output of line_profiler running primes_sieve with n = 100,000,000 (one hundred million):
Timer unit: 1e-09 s

Total time: 1.04878 s
File: /tmp/ipykernel_22262/2557137730.py
Function: primes_sieve at line 3

Line #      Hits         Time      Per Hit   % Time  Line Contents
==================================================================
     3                                               def primes_sieve(n):
     4         1   14264754.0   14264754.0      1.4      primes = np.ones(n+1, dtype=bool)
     5         1      12394.0      12394.0      0.0      primes[:2] = False
     6         1   16238905.0   16238905.0      1.5      primes[4::2] = False
     7      4999    1309955.0        262.0      0.1      for i in range(3, int(n**0.5)+1, 2):
     8      3771    1507909.0        399.9      0.1          if primes[i]:
     9      1228  909007228.0     740233.9     86.7              primes[i*i::i] = False
    10
    11         1  106434647.0  106434647.0     10.1      return np.where(primes)[0]
If you skipped more values of i, you could avoid the time spent on the loop header (for i in range(3, int(n**0.5)+1, 2):) and on the check (if primes[i]:). But you wouldn't avoid the time spent in primes[i*i::i] = False. Since the program spends 0.1% of its time on each of those two lines, you could save at most 0.2% of your execution time.