
Interpolate array along first axis with scipy


I'm struggling to perform a simple linear interpolation of a datacube of size Nx*Ny*Nz along its first axis, producing a new one of size Nxnew*Ny*Nz while the other two dimensions stay unchanged. RegularGridInterpolator from scipy seems to be the way to go, but it is not obvious to me how to generate its input.

from scipy.interpolate import RegularGridInterpolator
import numpy as np

x = np.linspace(1,4,11)
y = np.linspace(4,7,22)
z = np.linspace(7,9,33)
V = np.zeros((11,22,33))
for i in range(11):
    for j in range(22):
        for k in range(33):
            V[i,j,k] = 100*x[i] + 10*y[j] + z[k]
fn = RegularGridInterpolator((x,y,z), V)
pts = np.array([[[[2,6,8],[3,5,7]], [[2,6,8],[3,5,7]]]])
out = fn(pts)
print(out, out.shape)

In this MWE, I'd like to use new points xnew = np.linspace(2,3,50) while keeping y and z the same, so that the resulting array has shape (50,22,33). Also, how would one generalize this to interpolation along one dimension of an n-dimensional array, keeping the remaining coordinates the same?


Solution

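  • As suggested in the comment, you can replace the triply nested loop with a call to np.meshgrid, which makes the code more readable and efficient. Pass indexing='ij' so that the output axes follow the order of the inputs (x, y, z); the default 'xy' indexing would swap the first two axes.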

    x, y, z = np.meshgrid(x, y, z, indexing='ij')
    V = 100*x + 10*y + z
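A quick sanity check, assuming the grids from the question: the meshgrid version below should produce the same array as the original triple loop. The third variant uses plain NumPy broadcasting with np.newaxis, a different technique (not part of the original answer) that avoids building the three full coordinate arrays.

```python
import numpy as np

x0 = np.linspace(1, 4, 11)
y0 = np.linspace(4, 7, 22)
z0 = np.linspace(7, 9, 33)

# Original triple loop from the question
V_loop = np.zeros((11, 22, 33))
for i in range(11):
    for j in range(22):
        for k in range(33):
            V_loop[i, j, k] = 100*x0[i] + 10*y0[j] + z0[k]

# meshgrid with indexing='ij' keeps the axis order (x, y, z)
x, y, z = np.meshgrid(x0, y0, z0, indexing='ij')
V_mesh = 100*x + 10*y + z

# Broadcasting alternative: reshape each 1-D grid so it spans one axis
V_bcast = (100*x0[:, np.newaxis, np.newaxis]
           + 10*y0[np.newaxis, :, np.newaxis]
           + z0[np.newaxis, np.newaxis, :])

print(np.allclose(V_loop, V_mesh), np.allclose(V_loop, V_bcast))  # True True
```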
    

    As for generating the input to your fn object, note that its __call__ method expects an input of shape (..., ndim). Here, ... is your desired output shape (50, 22, 33), and ndim is the number of coordinates (3 for x, y, and z). We can use meshgrid to generate the coordinates as three separate arrays, but to form the input to fn we need to join them so that the axis indexing the coordinates comes last. There are several ways to do this; the easiest is np.stack.

    from scipy.interpolate import RegularGridInterpolator
    import numpy as np
    
    x0 = np.linspace(1, 4, 11)
    y0 = np.linspace(4, 7, 22)
    z0 = np.linspace(7, 9, 33)
    
    x, y, z = np.meshgrid(x0, y0, z0, indexing='ij')
    V = 100*x + 10*y + z
    
    fn = RegularGridInterpolator((x0, y0, z0), V)
    
    xnew = np.linspace(2, 3, 50)
    x, y, z = np.meshgrid(xnew, y0, z0, indexing='ij')
    
    xi = np.stack((x, y, z), axis=-1)
    # or
    # xi = np.moveaxis(np.asarray([x, y, z]), 0, -1)
    # or
    # xi = np.concatenate((x[..., np.newaxis], y[..., np.newaxis], z[..., np.newaxis]), axis=-1)
    
    out = fn(xi)
    print(out.shape)
    # (50, 22, 33)
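To make the (..., ndim) convention concrete, the sketch below (assuming the same grids and data as in the answer) prints the shape of xi, whose last axis holds one (x, y, z) triple per output element, and compares the interpolated values with the formula that generated V. Because V is linear in each coordinate, linear interpolation should reproduce 100*x + 10*y + z up to floating-point rounding.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

# Same grids and data as in the answer
x0 = np.linspace(1, 4, 11)
y0 = np.linspace(4, 7, 22)
z0 = np.linspace(7, 9, 33)
X, Y, Z = np.meshgrid(x0, y0, z0, indexing='ij')
V = 100*X + 10*Y + Z
fn = RegularGridInterpolator((x0, y0, z0), V)  # method='linear' by default

# Query points: new x values, original y and z grids
xnew = np.linspace(2, 3, 50)
x, y, z = np.meshgrid(xnew, y0, z0, indexing='ij')
xi = np.stack((x, y, z), axis=-1)
print(xi.shape)      # (50, 22, 33, 3): the last axis holds one (x, y, z) triple
print(xi[0, 0, 0])   # [2. 4. 7.], i.e. (xnew[0], y0[0], z0[0])

out = fn(xi)
# V is linear in every coordinate, so linear interpolation is exact up to rounding
expected = 100*x + 10*y + z
print(out.shape, np.allclose(out, expected))  # (50, 22, 33) True
```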
    

    The meaning of "n" in your question about generalizing to "n-dimensional arrays" is ambiguous. Presumably "n" represents the number of coordinates, that is, the dimensionality of your V. In that case, generalizing to any number of coordinates is straightforward: just treat those coordinates (e.g. t, u, v, w) as we have treated x, y, and z in this example.
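Treating the coordinates generically can be sketched as a small helper that interpolates along one chosen axis of an array with any number of dimensions, keeping the remaining coordinates at their grid points. The name interp_along_axis is hypothetical (it is not a SciPy function); it simply builds the query points with np.meshgrid(*grids, indexing='ij') and np.stack as above. For comparison, the sketch also calls scipy.interpolate.interp1d with its axis argument, a different, purely one-dimensional tool (marked as legacy in recent SciPy documentation) that interpolates along a single axis directly; for linear interpolation both should agree.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator, interp1d


def interp_along_axis(grids, values, axis, new_coords):
    """Hypothetical helper: linearly interpolate `values`, defined on the
    tuple of 1-D `grids`, at `new_coords` along `axis`, keeping every other
    coordinate at its original grid points."""
    fn = RegularGridInterpolator(grids, values)
    # Replace the grid of the chosen axis with the new coordinates
    query_grids = list(grids)
    query_grids[axis] = np.asarray(new_coords)
    # One coordinate array per dimension, in 'ij' (matrix) order
    mesh = np.meshgrid(*query_grids, indexing='ij')
    # Coordinates go on the last axis: shape (..., ndim)
    xi = np.stack(mesh, axis=-1)
    return fn(xi)


# Four-dimensional example with coordinates t, u, v, w
t = np.linspace(0, 1, 5)
u = np.linspace(0, 2, 6)
v = np.linspace(1, 3, 7)
w = np.linspace(-1, 1, 8)
tt, uu, vv, ww = np.meshgrid(t, u, v, w, indexing='ij')
data = 1000*tt + 100*uu + 10*vv + ww  # linear, so linear interpolation is exact

unew = np.linspace(0.5, 1.5, 20)
out = interp_along_axis((t, u, v, w), data, axis=1, new_coords=unew)
print(out.shape)  # (5, 20, 7, 8)

# Alternative: 1-D linear interpolation along a single axis with interp1d
out_1d = interp1d(u, data, axis=1)(unew)
print(np.allclose(out, out_1d))  # True
```

Building the full query grid is simple but allocates an array of shape (..., ndim); when only one axis changes, interp1d (or another 1-D interpolator with an axis argument) avoids that extra memory.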