
Cythonize a partial differential equation integrator


I am trying to speed up a finite-difference integrator for a partial differential equation using Cython. I am not sure what I need to do to make Cython work correctly with NumPy arrays.

The diffusion term function that I use is:

import numpy

def laplacian(var, dh2):
    """ (1D array, dx^2) -> laplacian(1D array)
    periodic_laplacian_1D_4th_order
    Implementing the 4th order 1D laplacian with periodic condition
    """
    lap = numpy.zeros_like(var)
    # nearest neighbours: (4/3) * (var[i-1] + var[i+1]), wrapping at the ends
    lap[1:]    = (4.0/3.0)*var[:-1]
    lap[0]     = (4.0/3.0)*var[-1]   # periodic: left neighbour of var[0] is var[-1]
    lap[:-1]  += (4.0/3.0)*var[1:]
    lap[-1]   += (4.0/3.0)*var[0]
    # central point
    lap       += (-5.0/2.0)*var

    # second neighbours: (-1/12) * (var[i-2] + var[i+2]), wrapping at the ends
    lap[2:]   += (-1.0/12.0)*var[:-2]
    lap[:2]   += (-1.0/12.0)*var[-2:]
    lap[:-2]  += (-1.0/12.0)*var[2:]
    lap[-2:]  += (-1.0/12.0)*var[:2]

    return lap / dh2
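A quick way to sanity-check a periodic 4th-order stencil like this one is to rebuild it with numpy.roll, which handles the wraparound implicitly, and compare it with the exact second derivative of sin(x) on a periodic grid. This is a minimal sketch; the helper name laplacian_roll is hypothetical and not part of the original code.

```python
import numpy as np

def laplacian_roll(var, dh2):
    """4th-order periodic 1D Laplacian built with np.roll (hypothetical helper)."""
    # np.roll(var, 1)[i] == var[i-1] and np.roll(var, -1)[i] == var[i+1], with wraparound
    return ((4.0/3.0) * (np.roll(var, 1) + np.roll(var, -1))
            - (1.0/12.0) * (np.roll(var, 2) + np.roll(var, -2))
            - (5.0/2.0) * var) / dh2

# periodic grid on [0, 2*pi): the exact second derivative of sin(x) is -sin(x)
n = 200
x = np.linspace(0.0, 2.0*np.pi, n, endpoint=False)
dx = x[1] - x[0]
approx = laplacian_roll(np.sin(x), dx**2)
err = np.max(np.abs(approx + np.sin(x)))
print(err)  # a small number that should shrink roughly like dx**4
```

Any correct implementation of the stencil, including loop-based ones, should agree with it up to rounding, e.g. np.allclose(laplacian(v, dh2), laplacian_roll(v, dh2)) for a random array v.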

And the right-hand sides (rhs) of the model equations are:

from derivatives import laplacian

def dbdt(b,w,p,m,d,dx2):
    """ db/dt of Modified Klausmeier """
    return w*b**2 - m*b + laplacian(b,dx2)

def dwdt(b,w,p,m,d,dx2):
    """ dw/dt of Modified Klausmeier """
    # water diffuses with coefficient d, so the Laplacian acts on w, not b
    return p - w - w*b**2 + d*laplacian(w,dx2)
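For context, here is a minimal sketch of how right-hand sides like these are typically advanced in time with an explicit (forward Euler) finite-difference scheme. It is self-contained, so it restates the stencil with np.roll; the parameter values, grid size, time step and the helper name euler_step are hypothetical placeholders, not taken from the question. The Laplacian is evaluated for both fields at every step, which is why it is the natural target for optimization.

```python
import numpy as np

def laplacian(var, dh2):
    """4th-order periodic 1D Laplacian (np.roll formulation)."""
    return ((4.0/3.0) * (np.roll(var, 1) + np.roll(var, -1))
            - (1.0/12.0) * (np.roll(var, 2) + np.roll(var, -2))
            - (5.0/2.0) * var) / dh2

def dbdt(b, w, p, m, d, dx2):
    return w*b**2 - m*b + laplacian(b, dx2)

def dwdt(b, w, p, m, d, dx2):
    return p - w - w*b**2 + d*laplacian(w, dx2)

def euler_step(b, w, dt, p, m, d, dx2):
    """One forward Euler step (hypothetical helper); both rates use the old state."""
    db = dbdt(b, w, p, m, d, dx2)
    dw = dwdt(b, w, p, m, d, dx2)
    return b + dt*db, w + dt*dw

# hypothetical parameters and grid
p, m, d = 2.0, 0.45, 20.0
n, dx = 256, 0.5
dx2 = dx**2
# the stencil's most negative eigenvalue is -16/(3*dx2), so forward Euler on the
# diffusion part needs dt < 3*dx2/(8*D); b diffuses with D = 1 and w with D = d
dt = 0.5 * 3.0*dx2 / (8.0*max(1.0, d))

rng = np.random.default_rng(0)
b = 1.0 + 0.1*rng.random(n)
w = np.full(n, 1.0)
for _ in range(1000):
    b, w = euler_step(b, w, dt, p, m, d, dx2)
```

The factor 0.5 in dt is an arbitrary safety margin; the reaction terms can impose their own limit for other parameter values.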

How can I optimize those functions using Cython?

I have a repository on GitHub with my working code, which integrates the Gray-Scott model: Gray-Scott model integrator.


Solution

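  • To use Cython efficiently, you should make all loops explicit and make sure that the annotated report produced by cython -a (an HTML file highlighting the lines that call into the Python C-API) shows as few Python calls as possible. A first try would be: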

    # lap.pyx (typed memoryviews do not need 'cimport numpy')
    import numpy as np
    cimport cython

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    def laplacian(double [::1] var, double dh2):
        """ (1D array, dx^2) -> laplacian(1D array)
        periodic_laplacian_1D_4th_order
        Implementing the 4th order 1D laplacian with periodic condition
        """
        cdef Py_ssize_t n = var.shape[0]
        cdef double[::1] lap = np.zeros(n)
        cdef Py_ssize_t i
        # nearest neighbours, wrapping at the ends
        for i in range(0, n-1):
            lap[1+i] = (4.0/3.0)*var[i]
        lap[0]     = (4.0/3.0)*var[n-1]
        for i in range(0, n-1):
            lap[i]  += (4.0/3.0)*var[1+i]
        lap[n-1]   += (4.0/3.0)*var[0]
        # central point
        for i in range(0, n):
            lap[i]       += (-5.0/2.0)*var[i]

        # second neighbours, wrapping at the ends
        for i in range(0, n-2):
            lap[2+i]   += (-1.0/12.0)*var[i]
        for i in range(0, 2):
            lap[i]   += (-1.0/12.0)*var[n - 2 + i]
        for i in range(0, n-2):
            lap[i]   += (-1.0/12.0)*var[i+2]
        for i in range(0, 2):
            lap[n-2+i]  += (-1.0/12.0)*var[i]
        for i in range(0, n):
            lap[i]  /= dh2
        # hand a NumPy array (not a raw memoryview) back to the caller
        return np.asarray(lap)
    

    Now this gives you:

    $ python -m timeit -s 'import numpy as np; from lap import laplacian; var = np.random.rand(1000000); dh2 = .01' 'laplacian(var, dh2)'
    100 loops, best of 3: 11.5 msec per loop
    

    while the NumPy code gave:

    100 loops, best of 3: 18.5 msec per loop
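    The same kind of comparison can be scripted from Python with the standard timeit module. This is a minimal sketch: laplacian_numpy and best_time are hypothetical names, laplacian_numpy is an np.roll stand-in whose timing will differ from the slicing version above, and a compiled version would be passed in the same way after importing it (e.g. from lap import laplacian). No particular timings are implied; they depend on the machine.

```python
import timeit
import numpy as np

def laplacian_numpy(var, dh2):
    """Vectorized 4th-order periodic Laplacian, used here only as a timing target."""
    return ((4.0/3.0) * (np.roll(var, 1) + np.roll(var, -1))
            - (1.0/12.0) * (np.roll(var, 2) + np.roll(var, -2))
            - (5.0/2.0) * var) / dh2

def best_time(func, var, dh2, number=10, repeat=3):
    """Best average time per call in seconds, in the spirit of 'python -m timeit'."""
    times = timeit.repeat(lambda: func(var, dh2), number=number, repeat=repeat)
    return min(times) / number

var = np.random.rand(1000000)
dh2 = 0.01
print(f"{best_time(laplacian_numpy, var, dh2) * 1e3:.1f} msec per loop")
```

    Taking the minimum over repeats, as timeit's command-line interface does, reduces the influence of other processes on the measurement.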
    

    Note that the Cython version could be optimized further, for example by merging the loops into a single pass.
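    A merged version computes each lap[i] in one pass from its four neighbours, taking periodic indices modulo n. The sketch below is plain Python so it can be checked quickly against NumPy; the name laplacian_merged is hypothetical. To use it in lap.pyx you would add the same decorators and cdef declarations (Py_ssize_t indices, double[::1] memoryviews) as above. Writing (i + n - 1) % n rather than (i - 1) % n matters there: under cython.cdivision(True), % follows C semantics and gives a negative result for a negative left operand. In pure Python this loop is slow; only its structure is the point.

```python
import numpy as np

def laplacian_merged(var, dh2):
    """Single-pass 4th-order periodic Laplacian (loop structure for a Cython port)."""
    n = var.shape[0]
    lap = np.empty(n)
    inv_dh2 = 1.0 / dh2
    for i in range(n):
        # neighbour indices with periodic wraparound; adding n keeps them non-negative
        im1 = (i + n - 1) % n
        ip1 = (i + 1) % n
        im2 = (i + n - 2) % n
        ip2 = (i + 2) % n
        lap[i] = ((4.0/3.0) * (var[im1] + var[ip1])
                  - (1.0/12.0) * (var[im2] + var[ip2])
                  - (5.0/2.0) * var[i]) * inv_dh2
    return lap

# check against a vectorized np.roll formulation on random data
rng = np.random.default_rng(1)
v = rng.random(50)
ref = ((4.0/3.0) * (np.roll(v, 1) + np.roll(v, -1))
       - (1.0/12.0) * (np.roll(v, 2) + np.roll(v, -2))
       - (5.0/2.0) * v) / 0.01
print(np.allclose(laplacian_merged(v, 0.01), ref))  # True
```

    The modulo inside the loop has a cost of its own; handling the first two and last two points separately and running the interior loop without wraparound may be faster in compiled code.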

    I also tried a customized version of Pythran (i.e. one not yet committed to master). Without changing the original Python code, apart from adding the export comment, I got the same speedup as the Cython version, without the hassle of converting the code:

    #pythran export laplacian(float[], float)
    import numpy
    def laplacian(var, dh2):
        """ (1D array, dx^2) -> laplacian(1D array)
        periodic_laplacian_1D_4th_order
        Implementing the 4th order 1D laplacian with periodic condition
        """
        lap = numpy.zeros_like(var)
        lap[1:]    = (4.0/3.0)*var[:-1]
        lap[0]     = (4.0/3.0)*var[-1]   # periodic: left neighbour of var[0]
        lap[:-1]  += (4.0/3.0)*var[1:]
        lap[-1]   += (4.0/3.0)*var[0]
        lap       += (-5.0/2.0)*var

        lap[2:]   += (-1.0/12.0)*var[:-2]
        lap[:2]   += (-1.0/12.0)*var[-2:]
        lap[:-2]  += (-1.0/12.0)*var[2:]
        lap[-2:]  += (-1.0/12.0)*var[:2]

        return lap / dh2
    

    Converted with:

    $ pythran lap.py -O3
    

    And I get:

    100 loops, best of 3: 11.6 msec per loop