Tags: python, performance, numba

Random sample in numba


For performance reasons I often use Numba, and for my code I need to take a random sample without replacement. I found that I could use np.random.choice for that, but I noticed that it is extremely slow compared to random.sample. Am I doing something wrong? How can I improve the performance of the Numba function? I boiled my code down to this minimal example:

import random
from timeit import timeit

import numpy as np
import numba as nb

def func2():
    List = range(100000)
    for x in range(20000):
        random.sample(List, 10)

@nb.njit()
def func3():
    Array = np.arange(100000)
    for x in range(20000):
        np.random.choice(Array, 10, False)

print(timeit(lambda: func2(), number=1))
print(timeit(lambda: func3(), number=1))
>>>0.1196
>>>20.1245

Edit: I'm now using my own sample function, which is much faster than np.random.choice.

@nb.njit()
def func4():
    for x in range(20000):
        rangeList = list(range(100000))
        result = []
        for i in range(10):
            randint = random.randint(0, len(rangeList) - 1)
            result.append(rangeList.pop(randint))
    return result

print(timeit(lambda: func4(), number=1))
>>>0.1767

Solution

  • TL;DR

    In general np.random.choice(replace=False) performs well, and generally better than random.sample() (which cannot be used inside Numba-JITted functions in nopython mode anyway). For smaller values of k, however, random_sample_shuffle_idx() is fastest, except when k is very small relative to n, in which case random_sample_set() should be used.

    Parallel Numba compilation may speed up execution for sufficiently large inputs, but it introduces race conditions that can invalidate the sampling, so it is best avoided.

    Discussion

    A simple, Numba-compatible re-implementation of random.sample() is easy to write:

    import random
    import numba as nb
    import numpy as np
    
    
    @nb.njit
    def random_sample_set(arr, k=-1):
        n = arr.size
        if k < 0:
            k = arr.size
        # Create an empty typed set: Numba cannot infer the type of a bare set().
        seen = {0}
        seen.clear()
        # Indices must be integers, regardless of arr's dtype.
        index = np.empty(k, dtype=np.int64)
        for i in range(k):
            j = random.randint(0, n - 1)
            while j in seen:
                j = random.randint(0, n - 1)
            seen.add(j)
            index[i] = j
        return arr[index]
    

    This uses a temporary set named seen that stores all previously drawn indices, so that a draw is retried whenever an index has already been used.

    In the benchmarks below, this is consistently faster than random.sample().


    A potentially much faster version of this can be written by just shuffling (a portion of a copy of) the input:

    import random
    import numba as nb
    
    
    @nb.njit
    def random_sample_shuffle(arr, k=-1):
        n = arr.size
        if k < 0:
            k = arr.size
        result = arr.copy()
        for i in range(k):
            j = random.randint(i, n - 1)
            result[i], result[j] = result[j], result[i]
        return result[:k].copy()
    

    This is faster than random_sample_set() as long as k is sufficiently large. The larger the input array, the larger k needs to be to outperform random_sample_set(), because random_sample_shuffle() always copies (and partially shuffles) the full input, while random_sample_set() only pays for k draws, until a large k makes the j in seen test collide often and the while-loop run many times on average.
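    The collision cost can be quantified: once i indices have been taken, the next uniform draw succeeds with probability (n - i) / n, so it takes n / (n - i) draws on average. A quick sketch (the helper expected_draws is illustrative, not part of the code above):

```python
def expected_draws(n, k):
    # Expected total draws for rejection-sampling k distinct indices out of n:
    # the i-th pick succeeds with probability (n - i) / n, hence it needs
    # n / (n - i) draws on average (geometric distribution).
    return sum(n / (n - i) for i in range(k))

print(expected_draws(100, 10))  # ~10.5: almost no retries while k << n
print(expected_draws(100, 90))  # ~226: heavy retrying as k approaches n
```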

    Also, random_sample_shuffle() has a memory footprint proportional to n and independent of k, which is higher than that of random_sample_set(), whose memory footprint is proportional to k.


    A slight variation of random_sample_shuffle() is random_sample_shuffle_idx(), which shuffles indices instead of copying the input. Its memory footprint is independent of the element data type, making it more efficient for larger data types, and it is significantly faster for small values of k while typically on par with (or slightly slower than) random_sample_shuffle() in the general case:

    import random
    import numpy as np
    import numba as nb
    
    
    @nb.njit
    def random_sample_shuffle_idx(arr, k=-1):
        n = arr.size
        if k < 0:
            k = arr.size
        index = np.arange(n)
        for i in range(k):
            j = random.randint(i, n - 1)
            index[i], index[j] = index[j], index[i]
        return arr[index[:k]]
    
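    As a sanity check that the partial Fisher-Yates shuffle samples uniformly, here is a pure-Python analogue of the same index-shuffling logic (the helper sample_shuffle_idx is illustrative, mirroring random_sample_shuffle_idx() without Numba):

```python
import random
from collections import Counter

def sample_shuffle_idx(seq, k):
    # Partial Fisher-Yates over an index list, as in random_sample_shuffle_idx().
    n = len(seq)
    index = list(range(n))
    for i in range(k):
        j = random.randint(i, n - 1)
        index[i], index[j] = index[j], index[i]
    return [seq[index[i]] for i in range(k)]

# Empirically, every element should appear with probability ~k/n = 2/5:
random.seed(0)
trials = 20000
counts = Counter()
for _ in range(trials):
    counts.update(sample_shuffle_idx("abcde", 2))
for letter in "abcde":
    print(letter, counts[letter] / trials)
```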

    When comparing the above with np.random.choice(replace=False):

    import numpy as np
    import numba as nb
    
    
    @nb.njit
    def random_sample_choice(arr, k=-1):
        if k < 0:
            k = arr.size
        return np.random.choice(arr, k, replace=False)
    

    one would observe that this sits between random_sample_set() and random_sample_shuffle() when it comes to speed, as long as the input is sufficiently small and k is not too small.

    Of course np.random.choice() and its newer counterpart np.random.Generator.choice() offer a lot more functionality than these simple implementations.
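    For instance, outside of Numba the newer Generator API also supports weighted sampling and selecting along an axis; a brief sketch:

```python
import numpy as np

rng = np.random.default_rng(42)
arr = np.arange(100)

# Plain sampling without replacement, like the functions above:
sample = rng.choice(arr, size=10, replace=False)

# Weighted sampling without replacement via p=...:
weights = np.linspace(1.0, 2.0, arr.size)
weighted = rng.choice(arr, size=10, replace=False, p=weights / weights.sum())

# Sampling whole rows of a 2D array via axis=0:
mat = np.arange(12).reshape(4, 3)
rows = rng.choice(mat, size=2, replace=False, axis=0)
print(sample, weighted, rows, sep="\n")
```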


    Benchmarks

    Some quick benchmarks can be generated with the following:

    funcs = (
        random_sample_set,
        random_sample_shuffle,
        random_sample_shuffle_idx,
        random_sample_choice,
    )
    
    
    def is_good(x):
        return len(x) == len(set(x))
    
    
    for q in range(4, 24, 4):
        n = 2 ** q
        arr = np.arange(n)
        seq = arr.tolist()
        for k in range(n // 8, n + 1, n // 8):
            print(f"n = {n}, k = {k}")
            func = random.sample
            print(f"{func.__name__:>24s} {is_good(func(seq, k))!s:>5s}", end=" ")
            %timeit -n 1 -r 1 func(seq, k)
            for func in funcs:
                print(f"{func.__name__:>24s} {is_good(func(arr, k))!s:>5s}", end=" ")
                %timeit -n 4 -r 4 func(arr, k)
    

    The most interesting results are:

    ...
    n = 65536, k = 65536
                            sample  True 41 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
                 random_sample_set  True 22 ms ± 1 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
             random_sample_shuffle  True 1 ms ± 113 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
         random_sample_shuffle_idx  True 948 µs ± 94 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
              random_sample_choice  True 918 µs ± 67.7 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
    ...
    n = 1048576, k = 131072
                            sample  True 136 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
                 random_sample_set  True 19 ms ± 1.84 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
             random_sample_shuffle  True 5.85 ms ± 303 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
         random_sample_shuffle_idx  True 6.95 ms ± 445 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
              random_sample_choice  True 26.1 ms ± 1.93 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
    ...
    n = 1048576, k = 917504
                            sample  True 916 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
                 random_sample_set  True 313 ms ± 47.6 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
             random_sample_shuffle  True 29.4 ms ± 1.87 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
         random_sample_shuffle_idx  True 32.8 ms ± 1.55 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
              random_sample_choice  True 28.2 ms ± 1.06 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
    ...
    

    And the small k regime (which is the use-case of the question):

    for q in range(4, 28, 4):
        n = 2 ** q
        arr = np.arange(n)
        seq = arr.tolist()
        k = 16
        print(f"n = {n}, k = {k}")
        func = random.sample
        print(f"{func.__name__:>24s} {is_good(func(seq, k))!s:>5s}", end=" ")
        %timeit -n 1 -r 1 func(seq, k)
        for func in funcs:
            print(f"{func.__name__:>24s} {is_good(func(arr, k))!s:>5s}", end=" ")
            %timeit -n 4 -r 4 func(arr, k)
    
    n = 16, k = 16
                            sample  True 39.1 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
                 random_sample_set  True 5.11 µs ± 2.95 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
             random_sample_shuffle  True 2.62 µs ± 1.67 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
         random_sample_shuffle_idx  True 2.51 µs ± 1.47 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
              random_sample_choice  True 2.39 µs ± 1.47 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
    n = 256, k = 16
                            sample  True 43.7 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
                 random_sample_set  True 3.67 µs ± 2.36 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
             random_sample_shuffle  True 2.59 µs ± 1.72 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
         random_sample_shuffle_idx  True The slowest run took 4.44 times longer than the fastest. This could mean that an intermediate result is being cached.
    2.8 µs ± 2.16 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
              random_sample_choice  True 5.47 µs ± 1.8 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
    n = 4096, k = 16
                            sample  True 33.4 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
                 random_sample_set  True 3.53 µs ± 1.73 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
             random_sample_shuffle  True 4.2 µs ± 1.81 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
         random_sample_shuffle_idx  True 3.23 µs ± 1.46 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
              random_sample_choice  True 51.7 µs ± 4.82 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
    n = 65536, k = 16
                            sample  True 58.9 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
                 random_sample_set  True 4.15 µs ± 2.75 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
             random_sample_shuffle  True 35.3 µs ± 7.99 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
         random_sample_shuffle_idx  True 15.1 µs ± 5.03 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
              random_sample_choice  True The slowest run took 5.93 times longer than the fastest. This could mean that an intermediate result is being cached.
    2.3 ms ± 1.61 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
    n = 1048576, k = 16
                            sample  True 48.2 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
                 random_sample_set  True 3.89 µs ± 2.01 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
             random_sample_shuffle  True 1.87 ms ± 163 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
         random_sample_shuffle_idx  True 810 µs ± 195 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
              random_sample_choice  True 30.6 ms ± 2.41 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
    n = 16777216, k = 16
                            sample  True 70.4 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
                 random_sample_set  True 5.15 µs ± 3.65 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
             random_sample_shuffle  True 103 ms ± 1.84 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
         random_sample_shuffle_idx  True 75.2 ms ± 3.31 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
              random_sample_choice  True 863 ms ± 77.3 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)