Tags: python, scipy, numba, gamma-function

Why is this log gamma numba function slower than scipy for large arrays, but faster for single values?


I have a function to calculate the log gamma function that I am decorating with numba.njit.

import numpy as np
from numpy import log
from scipy.special import gammaln
from numba import njit

coefs = np.array([
    57.1562356658629235, -59.5979603554754912,
    14.1360979747417471, -0.491913816097620199,
    .339946499848118887e-4, .465236289270485756e-4,
    -.983744753048795646e-4, .158088703224912494e-3,
    -.210264441724104883e-3, .217439618115212643e-3,
    -.164318106536763890e-3, .844182239838527433e-4,
    -.261908384015814087e-4, .368991826595316234e-5
])

@njit(fastmath=True)
def gammaln_nr(z):
    """Numerical Recipes 6.1"""
    y = z
    tmp = z + 5.24218750000000000
    tmp = (z + 0.5) * log(tmp) - tmp
    ser = np.ones_like(y) * 0.999999999999997092

    n = coefs.shape[0]
    for j in range(n):
        y = y + 1
        ser = ser + coefs[j] / y

    out = tmp + log(2.5066282746310005 * ser / z)
    return out

When I use gammaln_nr for a large array, say np.linspace(0.001, 100, 10**7), my run time is about 7X slower than scipy (see code in appendix below). However, if I run for any individual value, my numba function is always about 2X faster. How is this happening?

z = 11.67
%timeit gammaln_nr(z)
%timeit gammaln(z)
>>> 470 ns ± 29.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
>>> 1.22 µs ± 28.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

My intuition is that if my function is faster for one value, it should be faster for an array of values. Of course, this may not be the case because I don't know whether numba is using SIMD instructions or some other sort of vectorization, whereas scipy may be.

Appendix


import time
import matplotlib.pyplot as plt
import seaborn as sns

n_trials = 8
sizes = np.logspace(0, n_trials - 1, n_trials)  # 1, 10, ..., 10**7
scipy_times = np.zeros(n_trials)
fastats_times = np.zeros(n_trials)

for i in range(n_trials):
    zs = np.linspace(0.001, 100, 10**i)  # evaluate gammaln over this range

    # Don't time the first call: it includes JIT compilation
    gammaln_nr(zs)

    start = time.perf_counter()
    gammaln_nr(zs)
    fastats_times[i] = time.perf_counter() - start

    start = time.perf_counter()
    gammaln(zs)
    scipy_times[i] = time.perf_counter() - start

fig, ax = plt.subplots(figsize=(12, 8))
# Recent seaborn versions require x= and y= as keyword arguments
sns.lineplot(x=sizes, y=fastats_times, label="numba", ax=ax);
sns.lineplot(x=sizes, y=scipy_times, label="scipy", ax=ax);
ax.set(xscale="log");
ax.set_xlabel("Array Size", fontsize=15);
ax.set_ylabel("Execution Time (s)", fontsize=15);
ax.set_title("Execution Time of Log Gamma");

[Figure: Execution Time of Log Gamma, execution time (s) vs. array size (log scale), numba vs. scipy]


Solution

  • Implementing gammaln in Numba

    Reimplementing frequently used functions can take considerable work. The challenge is not only matching the performance but also achieving a well-defined level of precision. The most direct approach is therefore to wrap an existing, working implementation.

    For gammaln, SciPy calls a compiled C implementation. The speed of the SciPy version therefore also depends on the compiler and compiler flags used to build SciPy and its dependencies.

    It is also not very surprising that the results for a single value can differ considerably from those for large arrays. For a single value, the fixed calling overhead (type conversions, input checking, ...) dominates. For large arrays, the per-element cost of the implementation itself becomes increasingly important.
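To see both effects in numbers, you can time a call on a one-element array and on a large array, then fit a rough linear model t(n) ~ overhead + n * cost_per_element. This sketch uses only timeit and scipy.special.gammaln. The helper per_call_seconds is hypothetical, and the resulting numbers depend entirely on your machine.

```python
import timeit

import numpy as np
from scipy.special import gammaln

def per_call_seconds(func, arg, number):
    # Best of 5 repeats, divided by the number of calls per repeat
    return min(timeit.repeat(lambda: func(arg), number=number, repeat=5)) / number

small = np.array([11.67])               # n = 1
large = np.linspace(0.001, 100, 10**5)  # n = 100000

t_small = per_call_seconds(gammaln, small, number=2000)
t_large = per_call_seconds(gammaln, large, number=20)

# Rough linear model: t(n) ~ overhead + n * cost_per_element
cost_per_element = (t_large - t_small) / (large.size - small.size)
overhead = t_small - cost_per_element

print(f"fixed overhead   ~ {overhead * 1e9:.0f} ns per call")
print(f"cost per element ~ {cost_per_element * 1e9:.2f} ns")
```

You can point the same helper at gammaln_nr, after a warm-up call so that compilation is excluded, to get the corresponding numbers for the Numba version.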

    Improving your implementation

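    • Write explicit loops. In Numba, vectorized array operations are expanded into loops that typically allocate a new temporary array per statement, and Numba then tries to join these loops. In the original gammaln_nr, each of the 14 iterations creates new arrays for y and ser, so a large input is read and written many times. It is often better to write out and join these loops manually.
    • Consider how basic arithmetic is implemented. Python always checks for division by zero and raises an exception in that case. The check adds a branch to every division, which is costly and can prevent SIMD vectorization. Numba follows this behaviour by default, but you can switch to NumPy-style error handling with error_model='numpy'. In that mode a division by zero yields inf (or NaN for 0/0) instead of raising. How NaN, Inf, and -0/+0 are handled in subsequent calculations is also affected by the fastmath flag.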

    Code

    import numpy as np
    from numpy import log
    from scipy.special import gammaln
    from numba import njit
    import numba as nb
    
    @njit(fastmath=True, error_model='numpy')
    def gammaln_nr(z):
        """Numerical Recipes 6.1, explicit loop over a 1D float64 array."""
        # Don't use global arrays: Numba freezes them as compile-time constants,
        # so changing them later has no effect until the function is recompiled.
        coefs = np.array([
            57.1562356658629235, -59.5979603554754912,
            14.1360979747417471, -0.491913816097620199,
            .339946499848118887e-4, .465236289270485756e-4,
            -.983744753048795646e-4, .158088703224912494e-3,
            -.210264441724104883e-3, .217439618115212643e-3,
            -.164318106536763890e-3, .844182239838527433e-4,
            -.261908384015814087e-4, .368991826595316234e-5])

        out = np.empty(z.shape[0])

        for i in range(z.shape[0]):
            y = z[i]
            tmp = z[i] + 5.24218750000000000
            tmp = (z[i] + 0.5) * np.log(tmp) - tmp
            ser = 0.999999999999997092

            # Series for this element only: scalars, no temporary arrays
            n = coefs.shape[0]
            for j in range(n):
                y = y + 1.
                ser = ser + coefs[j] / y

            out[i] = tmp + np.log(2.5066282746310005 * ser / z[i])
        return out
    
    @njit(fastmath=True, error_model='numpy', parallel=True)
    def gammaln_nr_p(z):
        """Numerical Recipes 6.1, parallel explicit loop over a 1D float64 array."""
        # Don't use global arrays: Numba freezes them as compile-time constants,
        # so changing them later has no effect until the function is recompiled.
        coefs = np.array([
            57.1562356658629235, -59.5979603554754912,
            14.1360979747417471, -0.491913816097620199,
            .339946499848118887e-4, .465236289270485756e-4,
            -.983744753048795646e-4, .158088703224912494e-3,
            -.210264441724104883e-3, .217439618115212643e-3,
            -.164318106536763890e-3, .844182239838527433e-4,
            -.261908384015814087e-4, .368991826595316234e-5])

        out = np.empty(z.shape[0])

        # prange distributes the outer loop across threads
        for i in nb.prange(z.shape[0]):
            y = z[i]
            tmp = z[i] + 5.24218750000000000
            tmp = (z[i] + 0.5) * np.log(tmp) - tmp
            ser = 0.999999999999997092

            n = coefs.shape[0]
            for j in range(n):
                y = y + 1.
                ser = ser + coefs[j] / y

            out[i] = tmp + np.log(2.5066282746310005 * ser / z[i])
        return out
    
    
    import matplotlib.pyplot as plt
    import seaborn as sns
    import time

    n_trials = 8
    sizes = np.logspace(0, n_trials - 1, n_trials)  # 1, 10, ..., 10**7
    scipy_times = np.zeros(n_trials)
    fastats_times = np.zeros(n_trials)
    fastats_times_p = np.zeros(n_trials)

    for i in range(n_trials):
        zs = np.linspace(0.001, 100, 10**i)  # evaluate gammaln over this range

        # Warm-up calls, not timed: the first call of each function includes JIT compilation
        gammaln_nr(zs)
        gammaln_nr_p(zs)

        start = time.perf_counter()
        arr_1 = gammaln_nr(zs)
        fastats_times[i] = time.perf_counter() - start

        start = time.perf_counter()
        arr_3 = gammaln_nr_p(zs)
        fastats_times_p[i] = time.perf_counter() - start

        start = time.perf_counter()
        arr_2 = gammaln(zs)
        scipy_times[i] = time.perf_counter() - start

        # Check that all implementations agree
        print(np.allclose(arr_1, arr_2))
        print(np.allclose(arr_1, arr_3))

    fig, ax = plt.subplots(figsize=(12, 8))
    # Recent seaborn versions require x= and y= as keyword arguments
    sns.lineplot(x=sizes, y=fastats_times, label="numba", ax=ax)
    sns.lineplot(x=sizes, y=fastats_times_p, label="numba_parallel", ax=ax)
    sns.lineplot(x=sizes, y=scipy_times, label="scipy", ax=ax)
    ax.set(xscale="log")
    ax.set_xlabel("Array Size", fontsize=15)
    ax.set_ylabel("Execution Time (s)", fontsize=15)
    ax.set_title("Execution Time of Log Gamma")
    plt.show()