python numpy scipy cupy

How to do cubic spline interpolation in Python with CuPy?


I am writing GPU code that has to perform cubic spline interpolation many times. I know how to do it with NumPy/SciPy, for example using

scipy.interpolate.splrep

or

scipy.interpolate.interp1d(kind='cubic')

I am currently using interp1d on NumPy arrays, but I need to run the same computation on CuPy arrays.

How can I do this with CuPy? I have arrays of x-values and y-values, plus an array of new x-values. My code needs to compute the interpolated y-values at the new x-values.
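As a CPU baseline for what the GPU code has to reproduce, the sketch below (with made-up sample data) compares `interp1d(kind='cubic')` with `scipy.interpolate.CubicSpline`. Both build an interpolating cubic spline with not-a-knot end conditions, so their results should agree to rounding error, which is why a port of `CubicSpline` can stand in for `interp1d`.

```python
import numpy as np
from scipy.interpolate import CubicSpline, interp1d

# Made-up sample data: known points and the new x-values to interpolate at.
x = np.linspace(0.0, 10.0, 11)
y = np.cos(x ** 2 / 8.0)
x_new = np.linspace(0.0, 10.0, 101)

# The CPU approach from the question.
y_new_interp1d = interp1d(x, y, kind="cubic")(x_new)

# CubicSpline with its default 'not-a-knot' end conditions builds the same
# interpolating cubic spline.
y_new_cs = CubicSpline(x, y)(x_new)

print(np.allclose(y_new_interp1d, y_new_cs))  # True
```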


Solution

  • I can answer this question myself now: I have finished the interpolation algorithm. Following the suggestion in ev-br's answer, I rewrote scipy.interpolate.CubicSpline for CuPy.

    scipy.interpolate.CubicSpline contains many functions that are not needed if you only want an interpolation function.

    The class CubicSpline has a parent class, scipy.interpolate.PPoly, which also contains many functions you do not need for plain interpolation. After cleaning it up, I only kept the class _PPolyBase and the functions solve_banded() and prepare_input().
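The coefficient step is where `solve_banded()` comes in: fitting a cubic spline reduces to a tridiagonal (banded) linear system. The sketch below is a simplified illustration, not the actual `CubicSpline` code. It assumes 1-D real `y` and uses natural end conditions (zero second derivative at both ends) instead of SciPy's default not-a-knot, because that keeps the band structure simple. `natural_spline_coeffs` is a hypothetical helper; it returns coefficients in the same `(4, n-1)` layout that `PPoly`/`CubicSpline` store in `.c`.

```python
import numpy as np
from scipy.interpolate import CubicSpline
from scipy.linalg import solve_banded


def natural_spline_coeffs(x, y):
    """Hypothetical helper: PPoly-layout coefficients of a natural cubic spline.

    c[k, i] multiplies (t - x[i]) ** (3 - k) on the interval [x[i], x[i+1]].
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.size
    h = np.diff(x)             # interval widths, length n-1
    slope = np.diff(y) / h     # secant slopes, length n-1

    # Tridiagonal system for the second derivatives M[0..n-1], stored in the
    # (1, 1) banded layout solve_banded expects: row 0 = superdiagonal,
    # row 1 = main diagonal, row 2 = subdiagonal.
    ab = np.zeros((3, n))
    ab[1, 0] = ab[1, -1] = 1.0            # natural ends: M[0] = M[n-1] = 0
    ab[1, 1:-1] = 2.0 * (h[:-1] + h[1:])
    ab[0, 2:] = h[1:]
    ab[2, :-2] = h[:-1]
    rhs = np.zeros(n)
    rhs[1:-1] = 6.0 * (slope[1:] - slope[:-1])
    M = solve_banded((1, 1), ab, rhs)

    c = np.empty((4, n - 1))
    c[0] = (M[1:] - M[:-1]) / (6.0 * h)               # cubic term
    c[1] = M[:-1] / 2.0                                # quadratic term
    c[2] = slope - h * (2.0 * M[:-1] + M[1:]) / 6.0    # linear term
    c[3] = y[:-1]                                      # constant term
    return c


x = np.linspace(0.0, 2.0 * np.pi, 10)
c = natural_spline_coeffs(x, np.sin(x))
print(np.allclose(c, CubicSpline(x, np.sin(x), bc_type="natural").c))  # True
```

Porting CubicSpline's own not-a-knot system follows the same pattern: build the banded matrix, solve it once, then turn the solution into per-interval polynomial coefficients.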

    The hardest part is the function evaluate(), which is written in Cython. Cython cannot work with CuPy arrays, so I used Numba, which supports CUDA, to accelerate the loops.
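For reference, this is roughly what `evaluate()` computes for a 1-D spline with no derivative order and extrapolation switched on: find the interval that contains each new x-value, then evaluate that interval's polynomial with Horner's rule. The sketch below is a vectorized NumPy stand-in, not the Cython original, and `evaluate_reference` is a hypothetical name. It is handy as a CPU check for a GPU port. Since `searchsorted`, `clip`, and fancy indexing also exist in CuPy, swapping `np` for `cupy` may give a quick GPU version too (an untested assumption).

```python
import numpy as np
from scipy.interpolate import CubicSpline


def evaluate_reference(c, x, xp):
    """Hypothetical helper: evaluate PPoly-layout coefficients c at points xp.

    c[k, i] multiplies (t - x[i]) ** (K - 1 - k) on [x[i], x[i+1]], where K is
    the number of coefficients per interval (4 for a cubic). Points outside
    [x[0], x[-1]] are extrapolated from the first or last interval.
    """
    x = np.asarray(x, dtype=float)
    xp = np.asarray(xp, dtype=float)
    # Interval index i with x[i] <= xp < x[i+1], clipped to the valid range.
    i = np.clip(np.searchsorted(x, xp, side="right") - 1, 0, x.size - 2)
    t = xp - x[i]
    # Horner's rule, starting from the highest-order coefficient.
    out = c[0, i]
    for k in range(1, c.shape[0]):
        out = out * t + c[k, i]
    return out


x = np.linspace(0.0, 1.0, 8)
cs = CubicSpline(x, np.exp(x))
xp = np.linspace(-0.1, 1.1, 50)
print(np.allclose(evaluate_reference(cs.c, cs.x, xp), cs(xp)))  # True
```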

    The signature of evaluate() looks like this:

    from numba import cuda

    @cuda.jit('void(complex128[:,:,:], float64[:], float64[:], complex128[:,:])')
    def evaluate(c, x, xp, out):
    

    One important thing to note is that the evaluation function is not thread-safe, so it cannot be parallelized as a whole.

    Only the first loop in evaluate(), which is:

    for ip in range(len(xp)):
        xval = xp[ip]
        ...
    

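    can be parallelized with cuda.grid(1) and cuda.gridsize(1), as a grid-stride loop over the new x-values; the inner loops stay sequential within each thread.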

    I also inlined evaluate_poly1() and find_interval_descending() into evaluate() to better fit Numba's CUDA support.

    It runs about 3 to 4 times faster than the original SciPy function.

    The code can be found here: https://github.com/GavinJiacheng/Interpolation_CUPY