For performance reasons I often use Numba, and in my code I need to take a random sample without replacement. I found that I could use np.random.choice for that, but I noticed that it is extremely slow compared to random.sample. Am I doing something wrong? How could I improve the performance of the Numba function? I boiled my code down to this minimal example:
import random
from timeit import timeit

import numpy as np
import numba as nb

def func2():
    List = range(100000)
    for x in range(20000):
        random.sample(List, 10)

@nb.njit()
def func3():
    Array = np.arange(100000)
    for x in range(20000):
        np.random.choice(Array, 10, False)

print(timeit(lambda: func2(), number=1))
print(timeit(lambda: func3(), number=1))
>>>0.1196
>>>20.1245
Edit: I'm now using my own sample function, which is much faster than np.random.choice.
@nb.njit()
def func4():
    for x in range(20000):
        rangeList = list(range(100000))
        result = []
        for i in range(10):
            randint = random.randint(0, len(rangeList) - 1)
            result.append(rangeList.pop(randint))
    return result

print(timeit(lambda: func4(), number=1))
>>>0.1767
In general np.random.choice(replace=False) performs well, and generally better than random.sample() (which cannot be used inside Numba JIT-compiled functions in nopython mode anyway), but for smaller values of k it is best to use random_sample_shuffle_idx(), except when k is very small, in which case random_sample_set() should be used.
Parallel Numba compilation may speed up execution for sufficiently large inputs, but it introduces race conditions that can invalidate the sampling (e.g. by producing duplicates), so it is best avoided; a sketch of the problematic pattern follows.
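For illustration only, here is a hypothetical parallelized variant of the shuffle approach used below (the name and the use of nb.prange are mine, not part of the actual solutions) showing why this goes wrong:

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def racy_partial_shuffle(arr, k):
    # DO NOT use this in practice: iterations of the prange loop swap
    # overlapping slots of out concurrently, so the swaps race against
    # each other and duplicates can appear in the result.
    n = arr.size
    out = arr.copy()
    for i in nb.prange(k):
        j = np.random.randint(i, n)      # draw j in [i, n - 1]
        out[i], out[j] = out[j], out[i]  # unsynchronized read-modify-write
    return out[:k].copy()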
A simple Numba-compatible re-implementation of random.sample() can easily be written:
import random

import numba as nb
import numpy as np

@nb.njit
def random_sample_set(arr, k=-1):
    n = arr.size
    if k < 0:
        k = arr.size
    # Create an empty typed set (Numba needs a non-empty literal to infer the type)
    seen = {0}
    seen.clear()
    # Indices must be integers regardless of the dtype of arr
    index = np.empty(k, dtype=np.int64)
    for i in range(k):
        # Re-draw until an index not picked before is found
        j = random.randint(0, n - 1)
        while j in seen:
            j = random.randint(0, n - 1)
        seen.add(j)
        index[i] = j
    return arr[index]
This uses a temporary set() named seen, which stores all the indices seen previously and avoids re-using them. This should always be faster than random.sample().
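For example (array size and k chosen arbitrarily):

arr = np.arange(1000000)
sample = random_sample_set(arr, 10)
print(sample)  # 10 distinct elements of arr
assert len(set(sample)) == len(sample)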
A potentially much faster version of this can be written by just shuffling (a portion of a copy of) the input:
import random

import numba as nb

@nb.njit
def random_sample_shuffle(arr, k=-1):
    n = arr.size
    if k < 0:
        k = arr.size
    result = arr.copy()
    # Partial Fisher-Yates shuffle: only the first k positions are finalized
    for i in range(k):
        j = random.randint(i, n - 1)
        result[i], result[j] = result[j], result[i]
    return result[:k].copy()
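Note that with the default k=-1 the whole input is shuffled, so the result is a full random permutation; for example:

arr = np.arange(8)
print(random_sample_shuffle(arr))     # a full permutation of 0..7
print(random_sample_shuffle(arr, 3))  # 3 distinct elements of 0..7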
This is faster than random_sample_set() as long as the k parameter is sufficiently large. The larger the input array, the larger the k parameter needs to be to outperform random_sample_set(). This is because random_sample_set() will then have a high collision rate of j in seen, causing the while-loop to run multiple times on average (a rough estimate follows). Also, random_sample_shuffle() has a memory footprint independent of k, which is higher than that of random_sample_set(), whose memory footprint is proportional to k.
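Assuming uniform, independent draws, the expected number of randint() calls for the i-th pick is n / (n - i), so the total for k close to n approaches n·ln(n) draws (coupon-collector behavior); a quick way to see the numbers:

def expected_draws(n, k):
    # Expected total randint() calls in random_sample_set(), assuming
    # each draw is uniform over [0, n - 1] and independent
    return sum(n / (n - i) for i in range(k))

print(expected_draws(65536, 16))     # ~16.0: collisions are negligible
print(expected_draws(65536, 65536))  # ~765000: roughly 11.7 draws per pick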
A slight variation of random_sample_shuffle() is random_sample_shuffle_idx(), which shuffles indices instead of a copy of the input. This has a memory footprint independent of the data type, making it more efficient for larger data types, and it is significantly faster for small values of k and typically on par with (or slightly slower than) random_sample_shuffle() in the general case:
import random

import numpy as np
import numba as nb

@nb.njit
def random_sample_shuffle_idx(arr, k=-1):
    n = arr.size
    if k < 0:
        k = arr.size
    # Partial Fisher-Yates shuffle on the indices, not on the data itself
    index = np.arange(n)
    for i in range(k):
        j = random.randint(i, n - 1)
        index[i], index[j] = index[j], index[i]
    return arr[index[:k]]
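Since only the integer index array is shuffled, the same function works unchanged for any input dtype, e.g. a float array:

data = np.random.random(1000)  # float64 input, only indexed, never copied whole
picked = random_sample_shuffle_idx(data, 5)
print(picked)  # 5 values from distinct positions of data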
When comparing the above with np.random.choice(replace=False):
import numpy as np
import numba as nb

@nb.njit
def random_sample_choice(arr, k=-1):
    if k < 0:
        k = arr.size
    # Thin wrapper so that np.random.choice() runs inside compiled code
    return np.random.choice(arr, k, replace=False)
one would observe that it sits between random_sample_set() and random_sample_shuffle() when it comes to speed, as long as the input is sufficiently small and k is not too small.
Of course, np.random.choice() and its newer counterpart random.Generator.choice() offer a lot more functionality than these simple implementations.
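For example, outside of compiled code the newer Generator API adds seeding and probability weights (values below are arbitrary):

rng = np.random.default_rng(42)  # seeded, hence reproducible
print(rng.choice(np.arange(100000), size=10, replace=False))
# Weighted sampling without replacement, not offered by the implementations above:
print(rng.choice(5, size=3, replace=False, p=[0.5, 0.2, 0.1, 0.1, 0.1]))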
Some quick benchmarks can be generated with the following (to be run in IPython or Jupyter, since it uses the %timeit magic):
funcs = (
    random_sample_set, random_sample_shuffle,
    random_sample_shuffle_idx, random_sample_choice)

def is_good(x):
    # A sample is valid if all of its elements are distinct
    return len(x) == len(set(x))

for q in range(4, 24, 4):
    n = 2 ** q
    arr = np.arange(n)
    seq = arr.tolist()
    for k in range(n // 8, n + 1, n // 8):
        print(f"n = {n}, k = {k}")
        func = random.sample
        print(f"{func.__name__:>24s} {is_good(func(seq, k))!s:>5s}", end=" ")
        %timeit -n 1 -r 1 func(seq, k)
        for func in funcs:
            print(f"{func.__name__:>24s} {is_good(func(arr, k))!s:>5s}", end=" ")
            %timeit -n 4 -r 4 func(arr, k)
The most interesting results are:
...
n = 65536, k = 65536
sample True 41 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 22 ms ± 1 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 1 ms ± 113 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 948 µs ± 94 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 918 µs ± 67.7 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
...
n = 1048576, k = 131072
sample True 136 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 19 ms ± 1.84 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 5.85 ms ± 303 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 6.95 ms ± 445 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 26.1 ms ± 1.93 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
...
n = 1048576, k = 917504
sample True 916 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 313 ms ± 47.6 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 29.4 ms ± 1.87 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 32.8 ms ± 1.55 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 28.2 ms ± 1.06 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
...
And the small k regime (which is the use case of the question):
for q in range(4, 28, 4):
    n = 2 ** q
    arr = np.arange(n)
    seq = arr.tolist()
    k = 16
    print(f"n = {n}, k = {k}")
    func = random.sample
    print(f"{func.__name__:>24s} {is_good(func(seq, k))!s:>5s}", end=" ")
    %timeit -n 1 -r 1 func(seq, k)
    for func in funcs:
        print(f"{func.__name__:>24s} {is_good(func(arr, k))!s:>5s}", end=" ")
        %timeit -n 4 -r 4 func(arr, k)
n = 16, k = 16
sample True 39.1 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 5.11 µs ± 2.95 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 2.62 µs ± 1.67 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 2.51 µs ± 1.47 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 2.39 µs ± 1.47 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 256, k = 16
sample True 43.7 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 3.67 µs ± 2.36 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 2.59 µs ± 1.72 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True The slowest run took 4.44 times longer than the fastest. This could mean that an intermediate result is being cached.
2.8 µs ± 2.16 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 5.47 µs ± 1.8 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 4096, k = 16
sample True 33.4 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 3.53 µs ± 1.73 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 4.2 µs ± 1.81 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 3.23 µs ± 1.46 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 51.7 µs ± 4.82 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 65536, k = 16
sample True 58.9 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 4.15 µs ± 2.75 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 35.3 µs ± 7.99 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 15.1 µs ± 5.03 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True The slowest run took 5.93 times longer than the fastest. This could mean that an intermediate result is being cached.
2.3 ms ± 1.61 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 1048576, k = 16
sample True 48.2 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 3.89 µs ± 2.01 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 1.87 ms ± 163 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 810 µs ± 195 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 30.6 ms ± 2.41 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 16777216, k = 16
sample True 70.4 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 5.15 µs ± 3.65 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 103 ms ± 1.84 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 75.2 ms ± 3.31 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 863 ms ± 77.3 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)