Tags: python, numpy, fortran, interpolation

Achieving NumPy-like fast interpolation in Fortran


I have a numerical routine that solves a certain equation and contains a few nested for loops. I initially wrote this routine in Python, using numba.jit to achieve acceptable performance. For large system sizes, however, this method becomes quite slow, so I have been rewriting the routine in Fortran, hoping to achieve a speed-up. However, I have found that my Fortran version is much slower than the Python version, by a factor of 2-3.

I believe the bottleneck is a linear interpolation function that is called in every iteration of the innermost loop. In the Python implementation I use numpy.interp, which seems to be pretty fast when combined with numba.jit. In Fortran I wrote my own interpolation function, which reads:

  complex*16 function interp(x_dat, y_dat, N_dat, x)
    implicit none
    integer, intent(in) :: N_dat
    real*8, dimension(N_dat), intent(in) :: x_dat
    complex*16, dimension(N_dat), intent(in) :: y_dat
    real*8, intent(in) :: x

    complex*16 :: y, y1, y2
    integer :: i, i1, i2, im

    if(x <= x_dat(1)) then
      y = y_dat(1)
    else if(x >= x_dat(N_dat)) then
      y = y_dat(N_dat)
    else
      im = MINLOC(DABS(x_dat - x), DIM=1)
      if(x_dat(im) >=x ) then
        i1 = im
        i2 = im - 1
      else
        i1 = im + 1
        i2 = im
      end if

      y1 = y_dat(i1)
      y2 = y_dat(i2)

      y = y1 + (x-x_dat(i1))*(y2 - y1)/(x_dat(i2) - x_dat(i1))
    end if
    interp = y
    return
  end function interp

Note that I need to interpolate complex data. If my diagnostics are correct, this function is much slower than numpy.interp, and because the interpolation is called in every loop iteration, it greatly reduces the speed of the whole program.

Does anyone know if there is a way to achieve NumPy-like interpolation speeds in Fortran? Or is my interpolation function above somehow horribly inefficient? I don't have much experience coding in Fortran yet.

Thanks!


Solution

  • At a guess (and see @IanBush's comments if you want to enable us to do better than guessing), it's the line

    im = MINLOC(DABS(x_dat - x), DIM=1)
    

    which is taking all your time, as this line is O(N) in the size of x_dat, whereas everything else is O(1).

    If x_dat is linearly spaced then you can replace this line with

    im = 1 + nint((N_dat-1)*(x-x_dat(1))/(x_dat(N_dat)-x_dat(1)))
    

    or better yet, skip im entirely and calculate i1 and i2 as

    i1 = 1 + floor((N_dat-1)*(x-x_dat(1))/(x_dat(N_dat)-x_dat(1)))
    i2 = i1 + 1
    

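    If x_dat is not linearly spaced but has other useful properties, then use those properties to calculate im where possible. For example, if x_dat is sorted in increasing order, a binary search finds the bracketing interval in O(log N) rather than O(N).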
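
Both cases can be put together in a minimal sketch. The module and function names here (interp_sketch, interp_uniform, interp_sorted, lerp) are made up for illustration and do not come from the question. interp_uniform assumes x_dat is evenly spaced and increasing. interp_sorted assumes only that x_dat is sorted in increasing order, which the question does not state, and finds the interval by bisection. Both keep the clamping behaviour of the original function at the ends of the table.

```fortran
module interp_sketch
  use, intrinsic :: iso_fortran_env, only: dp => real64
  implicit none

contains

  ! Linear interpolation between (xa, ya) and (xb, yb), evaluated at x.
  pure function lerp(xa, xb, ya, yb, x) result(y)
    real(dp),    intent(in) :: xa, xb, x
    complex(dp), intent(in) :: ya, yb
    complex(dp) :: y
    y = ya + (x - xa)*(yb - ya)/(xb - xa)
  end function lerp

  ! ASSUMPTION: x_dat is uniformly spaced and increasing.
  ! The lower index of the bracketing interval is computed directly, O(1).
  pure function interp_uniform(x_dat, y_dat, x) result(y)
    real(dp),    intent(in) :: x_dat(:)
    complex(dp), intent(in) :: y_dat(:)
    real(dp),    intent(in) :: x
    complex(dp) :: y
    integer :: n, i1

    n = size(x_dat)
    if (x <= x_dat(1)) then
      y = y_dat(1)
    else if (x >= x_dat(n)) then
      y = y_dat(n)
    else
      i1 = 1 + floor((n - 1)*(x - x_dat(1))/(x_dat(n) - x_dat(1)))
      i1 = max(1, min(i1, n - 1))   ! guard against floating-point round-off
      y = lerp(x_dat(i1), x_dat(i1 + 1), y_dat(i1), y_dat(i1 + 1), x)
    end if
  end function interp_uniform

  ! ASSUMPTION: x_dat is sorted in increasing order (spacing may vary).
  ! Bisection finds the bracketing interval in O(log N).
  pure function interp_sorted(x_dat, y_dat, x) result(y)
    real(dp),    intent(in) :: x_dat(:)
    complex(dp), intent(in) :: y_dat(:)
    real(dp),    intent(in) :: x
    complex(dp) :: y
    integer :: n, lo, hi, mid

    n = size(x_dat)
    if (x <= x_dat(1)) then
      y = y_dat(1)
    else if (x >= x_dat(n)) then
      y = y_dat(n)
    else
      lo = 1
      hi = n
      ! Invariant: x_dat(lo) <= x < x_dat(hi)
      do while (hi - lo > 1)
        mid = (lo + hi)/2
        if (x_dat(mid) <= x) then
          lo = mid
        else
          hi = mid
        end if
      end do
      y = lerp(x_dat(lo), x_dat(hi), y_dat(lo), y_dat(hi), x)
    end if
  end function interp_sorted

end module interp_sketch
```

With `use interp_sketch` in the calling code, a call looks like `y = interp_uniform(x_dat, y_dat, x)`. The array length is taken from `size(x_dat)`, so N_dat is not passed separately. Putting the functions in a module also gives the compiler the explicit interface that assumed-shape arrays require.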