Search code examples
python pytorch interpolation numeric

How to do cubic spline interpolation and integration in Pytorch


In PyTorch, is there a cubic spline interpolation routine similar to SciPy's? Given 1D input tensors x and y, I want to interpolate through those points and evaluate the interpolant at xs to obtain ys. I also want an integrator function that computes Ys, the integral of the spline interpolant from x[0] to each point in xs.


Solution

  • Here is a gist I made doing this with Cubic Hermite Splines in Pytorch efficiently and with autograd support.

    For convenience, I'll also put the code here.

    import torch as T
    
    def h_poly_helper(tt):
      """Combine four basis terms into the four cubic Hermite polynomials.

      `tt` is a sequence of four terms.  When they are the monomials
      [1, t, t**2, t**3] the result is [h00, h10, h01, h11] evaluated at t;
      when they are the integrated monomials the result is the corresponding
      antiderivatives.  Returns a list of four values.
      """
      # Row i holds the coefficients of the i-th Hermite basis polynomial.
      # The dtype is taken from the last term because the first term may be
      # a plain Python number rather than a tensor.
      coeffs = T.tensor([
          [1, 0, -3, 2],
          [0, 1, -2, 1],
          [0, 0, 3, -2],
          [0, 0, -1, 1]
          ], dtype=tt[-1].dtype)
      combined = []
      for row in coeffs:
        combined.append(sum( c*tt[j] for j, c in enumerate(row) ))
      return combined
    
    def h_poly(t):
      """Evaluate the four cubic Hermite basis polynomials at `t`.

      Builds the powers [1, t, t**2, t**3] by repeated multiplication and
      hands them to `h_poly_helper`, which returns a list of four values.
      """
      powers = [1]
      for _ in range(3):
        powers.append(powers[-1]*t)
      return h_poly_helper(powers)
    
    def H_poly(t):
      """Evaluate the antiderivatives of the cubic Hermite basis at `t`.

      Builds the integrated monomials [t, t**2/2, t**3/3, t**4/4] (each the
      integral from 0 of the matching power) and hands them to
      `h_poly_helper`, which returns a list of four values.
      """
      terms = [t]
      for k in range(1, 4):
        # previous term is t**k/k, so multiplying by t*k/(k+1) gives t**(k+1)/(k+1)
        terms.append(terms[-1]*t*k/(k+1))
      return h_poly_helper(terms)
    
    def interp_func(x, y):
      """Returns interpolating function.

      Builds a cubic Hermite spline through the points (x[i], y[i]) and
      returns a function f(xs) that evaluates it at the query points `xs`.

      x, y: 1D tensors of equal length; `x` must be sorted in ascending
        order because the segment lookup uses `searchsorted`.
      With a single point the spline is treated as the constant y[0].
      Queries outside [x[0], x[-1]] are extrapolated with the cubic of the
      nearest end segment.
      """
      if len(y)>1:
        # Tangents: secant slope at both ends, mean of the two adjacent
        # secant slopes at interior points.
        m = (y[1:] - y[:-1])/(x[1:] - x[:-1])
        m = T.cat([m[[0]], (m[1:] + m[:-1])/2, m[[-1]]])
      def f(xs):
        if len(y)==1: # in the case of 1 point, treat as constant function
          return y[0] + T.zeros_like(xs)
        # Index of the segment [x[I], x[I+1]] containing each query.  Clamp so
        # that queries beyond x[-1] use the last segment instead of indexing
        # past the end of x (queries below x[0] already map to segment 0).
        I = T.searchsorted(x[1:], xs).clamp(max=len(x)-2)
        dx = (x[I+1]-x[I])
        hh = h_poly((xs-x[I])/dx)
        return hh[0]*y[I] + hh[1]*m[I]*dx + hh[2]*y[I+1] + hh[3]*m[I+1]*dx
      return f
    
    def interp(x, y, xs):
      """Interpolate the points (x, y) with `interp_func` and evaluate at `xs`."""
      spline = interp_func(x, y)
      return spline(xs)
    
    def integ_func(x, y):
      """Returns integral of interpolating function.

      Builds the same cubic Hermite spline as `interp_func` and returns a
      function f(xs) giving the integral of that spline from x[0] to each
      query point in `xs`.

      x, y: 1D tensors of equal length; `x` must be sorted in ascending
        order because the segment lookup uses `searchsorted`.
      With a single point the integrand is treated as the constant y[0].
      Queries outside [x[0], x[-1]] integrate the extrapolated cubic of the
      nearest end segment.
      """
      if len(y)>1:
        # Tangents: secant slope at both ends, mean of the two adjacent
        # secant slopes at interior points.
        m = (y[1:] - y[:-1])/(x[1:] - x[:-1])
        m = T.cat([m[[0]], (m[1:] + m[:-1])/2, m[[-1]]])
        # Y[k] is the integral of the spline from x[0] to x[k]: the exact
        # integral over each whole segment, accumulated with cumsum.
        Y = T.zeros_like(y)
        Y[1:] = (x[1:]-x[:-1])*(
            (y[:-1]+y[1:])/2 + (m[:-1] - m[1:])*(x[1:]-x[:-1])/12
            )
        Y = Y.cumsum(0)
      def f(xs):
        if len(y)==1:
          return y[0]*(xs - x[0])
        # Index of the segment [x[I], x[I+1]] containing each query.  Clamp so
        # that queries beyond x[-1] use the last segment instead of indexing
        # past the end of x (queries below x[0] already map to segment 0).
        I = T.searchsorted(x[1:], xs).clamp(max=len(x)-2)
        dx = (x[I+1]-x[I])
        hh = H_poly((xs-x[I])/dx)
        # Integral up to the segment start plus the partial integral inside
        # the segment; the extra dx factor comes from the change of variable.
        return Y[I] + dx*(
            hh[0]*y[I] + hh[1]*m[I]*dx + hh[2]*y[I+1] + hh[3]*m[I+1]*dx
            )
      return f
    
    def integ(x, y, xs):
      """Integrate the spline through (x, y) with `integ_func`, evaluated at `xs`."""
      antiderivative = integ_func(x, y)
      return antiderivative(xs)
    
    # Example
    if __name__ == "__main__":
      import matplotlib.pylab as P # for plotting
      # Seven samples of sin on [0, 6], and a finer grid of 101 query points.
      knots = T.linspace(0, 6, 7)
      samples = knots.sin()
      grid = T.linspace(0, 6, 101)
      spline_vals = interp(knots, samples, grid)
      spline_integral = integ(knots, samples, grid)
      # Spline and its integral are each drawn next to the exact curve.
      P.scatter(knots, samples, label='Samples', color='purple')
      P.plot(grid, spline_vals, label='Interpolated curve')
      P.plot(grid, grid.sin(), '--', label='True Curve')
      P.plot(grid, spline_integral, label='Spline Integral')
      # 1 - cos(x) is the integral of sin from 0 to x
      P.plot(grid, 1-grid.cos(), '--', label='True Integral')
      P.legend()
      P.show()
    

    Resulting image from code example