Tags: python, 3d, conv-neural-network, pytorch, crop

How can I center-crop 3D volumes inside the network model with PyTorch?


In Keras, there is a Cropping3D layer for center-cropping tensors of 3D volumes inside the neural network. However, I could not find anything similar in PyTorch, though torchvision has torchvision.transforms.CenterCrop(size) for 2D images.

How can I do the cropping inside the network? Otherwise I would need to do it in preprocessing, which is the last thing I want to do for specific reasons.

Do I need to write a custom layer, for example one that slices the input tensors along each axis? I hope to get some inspiration for this.


Solution

  • In PyTorch you don't necessarily need to write layers for everything; often you can just do what you want directly during the forward pass. The basic rules to keep in mind when operating on torch tensors for which you will need to compute gradients are:

    1. Don't convert torch tensors to other types for computation (e.g. use torch.sum instead of converting to numpy and using numpy.sum).
    2. Don't perform in-place operations (e.g. changing one element of a tensor or using in-place operators; use x = x + ... instead of x += ...).
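
As a minimal sketch of the two rules above (the tensor values and variable names are illustrative assumptions, not part of the original answer), the snippet below keeps a reduction in torch so gradients flow, then shows one case where an in-place update breaks the backward pass while the out-of-place form works.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Rule 1: torch.sum keeps the result attached to the autograd graph.
torch.sum(x * x).backward()
grad_matches = torch.equal(x.grad, 2 * x.detach())  # d/dx sum(x^2) = 2x

# Rule 2: exp() saves its output for the backward pass, so modifying
# that output in place invalidates the graph.
y = torch.exp(x)
y += 1  # in-place update of a tensor autograd still needs
try:
    y.sum().backward()
    inplace_failed = False
except RuntimeError:
    inplace_failed = True

# The out-of-place form creates a new tensor and backpropagates fine.
x.grad = None
y = torch.exp(x)
y = y + 1
y.sum().backward()
```

Not every in-place operation fails like this; it depends on whether autograd needs the overwritten values, which is why avoiding them is a safe default.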

    That said, you can just use slicing. It might look something like this:

    def forward(self, x):
        ...
        x = self.conv3(x)
        x = x[:, :, 5:20, 5:20]    # crop out part of the feature map; for a 5D (N, C, D, H, W) volume, add a third slice
        x = self.relu3(x)
        ...
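
If a reusable layer is preferred over inline slicing, the same idea can be wrapped in a small module. The sketch below is one possible way to do it, not an existing PyTorch layer: the class name CenterCrop3d and its size argument are hypothetical, and it assumes input of shape (N, C, D, H, W). Unlike Keras' Cropping3D, which takes the amount to remove per side, this helper takes the target output size, similar to torchvision's CenterCrop.

```python
import torch
import torch.nn as nn


class CenterCrop3d(nn.Module):
    """Hypothetical helper: center-crop the last three dims of a tensor."""

    def __init__(self, size):
        super().__init__()
        self.size = tuple(size)  # target (depth, height, width)

    def forward(self, x):
        d, h, w = x.shape[-3:]
        td, th, tw = self.size
        if td > d or th > h or tw > w:
            raise ValueError("crop size must not exceed the input size")
        # Offsets that center the crop window along each spatial axis.
        d0, h0, w0 = (d - td) // 2, (h - th) // 2, (w - tw) // 2
        # Plain slicing is differentiable, so gradients flow through it.
        return x[..., d0:d0 + td, h0:h0 + th, w0:w0 + tw]


crop = CenterCrop3d((8, 16, 16))
x = torch.randn(2, 4, 12, 20, 20, requires_grad=True)
out = crop(x)
out.sum().backward()  # gradient is 1 inside the crop window, 0 outside
```

Because the module only slices, it has no parameters and can be placed anywhere in the forward pass, for example between a convolution and its activation.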