Tags: python, pytorch, resize, image-resizing, torchvision

How to resize a 5D PyTorch tensor using TorchVision transforms?


I'm trying to resize a 5D PyTorch tensor within a CNN but I'm having some trouble. I've attached my code and the error I'm getting.

import torch
import torchvision.transforms.functional as TF

a = torch.rand((1, 3, 20, 376, 376))

b = TF.resize(a, (200, 200))

ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [20, 376, 376] and output size of [200, 200]. Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.

The 5D tensor has the following dimensions:

(1, 3, 20, 376, 376) which corresponds to (BATCH SIZE, CHANNELS, DEPTH, HEIGHT, WIDTH).

The final 3 dimensions of the tensor refer to dimensions of 3D images being passed through the network. I'd like to resize the final two dimensions, the height and the width. Any guidance would be hugely appreciated. Thank you!

Unfortunately I can't convert the tensors to numpy arrays, resize, and then re-convert them to tensors as I'll lose the gradients needed for gradient descent in training.


Solution

  • One approach using TF.resize is to flatten the batch and depth dimensions together, perform the resize, then restore the original shape and axis order:

    >>> a_flat = a.swapaxes(1, 2).flatten(0, 1)                      # (B*D, C, H, W)
    >>> c = TF.resize(a_flat, (200, 200))                            # (B*D, C, 200, 200)
    >>> b = c.unflatten(0, (a.shape[0], a.shape[2])).swapaxes(1, 2)  # (B, C, D, 200, 200)
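  • An alternative sketch that avoids the reshape entirely: torch.nn.functional.interpolate accepts 5D input directly when given three spatial output sizes (D, H, W), so the depth can be passed through unchanged while only height and width are resized:

    ```python
    import torch
    import torch.nn.functional as F

    a = torch.rand(1, 3, 20, 376, 376)

    # For 5D input, interpolate expects three output sizes (D, H, W);
    # reusing the existing depth resizes only height and width.
    b = F.interpolate(a, size=(a.shape[2], 200, 200),
                      mode='trilinear', align_corners=False)
    print(b.shape)  # torch.Size([1, 3, 20, 200, 200])
    ```

    Like TF.resize, interpolate is differentiable, so gradients still flow through for training.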