I want to apply N-d interpolation to an (N+2)-d tensor for N>3.
```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 2, 3, 4, 5, 6, 7)
output_size = (7, 6, 5, 4, 3, 2)
y = F.interpolate(x, size=output_size, mode="linear")
```
The above code gives the following error:
```
NotImplementedError: Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got linear)
```
Note that the first two dimensions are batch size and channels (B, C), and are thus not interpolated, as stated in the docs.
How do I apply N-d linear interpolation for N>3?
N-d linear interpolation is separable: it is equivalent to applying 1-D linear interpolation along each interpolated dimension in succession, in any order. Wikipedia gives the following example diagram for 2-D (bilinear) interpolation:
*Figure (Wikipedia, "Bilinear interpolation"): apply linear interpolation along each dimension.*
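As a quick sanity check of the separability claim, the sketch below compares PyTorch's built-in `mode="trilinear"` on a 5-D tensor with three successive 1-D `mode="linear"` passes. The helper `linear_along_dim` is hypothetical (not part of PyTorch): it moves the target axis last and folds every other axis into the channel axis of a 3-D tensor. `float64` is used only to keep rounding differences negligible.

```python
import torch
import torch.nn.functional as F


def linear_along_dim(x, dim, out_len):
    """Hypothetical helper: 1-D linear interpolation of `x` along axis `dim` only."""
    moved = x.movedim(dim, -1)                      # target axis last
    lead_shape = moved.shape[:-1]
    flat = moved.reshape(1, -1, moved.shape[-1])    # (1, everything else, L)
    out = F.interpolate(flat, size=out_len, mode="linear", align_corners=False)
    return out.reshape(*lead_shape, out_len).movedim(-1, dim)


torch.manual_seed(0)
x = torch.randn(2, 3, 4, 5, 6, dtype=torch.float64)  # (B, C, D, H, W)
size = (7, 3, 9)

# Built-in 3-D linear interpolation.
reference = F.interpolate(x, size=size, mode="trilinear", align_corners=False)

# The same result from one 1-D pass per spatial dimension.
y = x
for d, n in zip(range(2, x.ndim), size):
    y = linear_along_dim(y, d, n)

print(torch.allclose(y, reference))
```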
Thus, one simple method is:
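```python
import math

import torch
import torch.nn.functional as F


def interpolate(input, size, scale_factor=None):
    """N-d linear interpolation over the spatial dims of a (B, C, *spatial) tensor."""
    assert input.ndim >= 3
    if scale_factor is not None:
        raise NotImplementedError
    output_shape = (*input.shape[:2], *size)
    assert len(input.shape) == len(output_shape)
    # Apply linear interpolation to each spatial dimension in turn.
    for i in range(2, 2 + len(size)):
        # Fold the dims before i into "channels" and the dims after i into the
        # "width" of a 4-D tensor; "bilinear" then resizes dim i while keeping
        # the width unchanged (an identity interpolation along the width).
        input_tail = math.prod(input.shape[i + 1 :])
        input = F.interpolate(
            input.reshape(
                input.shape[0], math.prod(input.shape[1:i]), input.shape[i], input_tail
            ),
            size=(output_shape[i], input_tail),
            mode="bilinear",
            align_corners=False,
        ).reshape(*input.shape[:i], output_shape[i], *input.shape[i + 1 :])
    return input.reshape(output_shape)
```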
Usage:
```python
x = torch.randn(1, 1, 2, 3, 4, 5, 6, 7)
output_size = (7, 6, 5, 4, 3, 2)
y = interpolate(x, size=output_size)
print(y.shape)  # torch.Size([1, 1, 7, 6, 5, 4, 3, 2])
```
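The `interpolate` function relies on a trick: `F.interpolate` has no N-d linear mode, but resizing only the height of a 4-D tensor with `mode="bilinear"` while keeping the width fixed reduces to 1-D linear interpolation along the height, because every output column lands exactly on an input column (weights 1 and 0). A small check of that equivalence, assuming the default `align_corners=False`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 5, 4, dtype=torch.float64)  # (N, C, H, W)

# Resize only H (5 -> 8) with "bilinear"; W stays 4, so it is not blended.
a = F.interpolate(x, size=(8, 4), mode="bilinear", align_corners=False)

# Same resize with 1-D "linear": put H last and fold W into the channel axis.
h_last = x.permute(0, 1, 3, 2).reshape(2, 3 * 4, 5)  # (N, C*W, H)
b = F.interpolate(h_last, size=8, mode="linear", align_corners=False)
b = b.reshape(2, 3, 4, 8).permute(0, 1, 3, 2)        # back to (N, C, H_out, W)

print(torch.allclose(a, b))  # True
```

This is why the function can fold every dimension after `i` into the width: those dimensions pass through each step untouched.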