In Keras, there is a Cropping3D layer for center-cropping tensors of 3D volumes inside the neural network. However, I couldn't find anything similar in PyTorch, though it has torchvision.transforms.CenterCrop(size) for 2D images.
How can I do the cropping inside the network? Otherwise I would need to do it in preprocessing, which is the last thing I want for specific reasons.
Do I need to write a custom layer that slices the input tensors along each axis? I hope to get some inspiration for this.
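The slicing idea from the question can be sketched as a small module that mirrors what Keras's Cropping3D does, namely removing a fixed number of voxels from each side of the three spatial axes. This is a minimal sketch under stated assumptions: the class name `Cropping3d` is hypothetical (not a PyTorch API), it only accepts the nested `((start, end), ...)` form of the cropping argument, and it assumes channels-first input of shape (N, C, D, H, W), the PyTorch convention, whereas Keras defaults to channels-last.

```python
import torch
import torch.nn as nn


class Cropping3d(nn.Module):
    """Hypothetical PyTorch counterpart of Keras's Cropping3D.

    Assumes channels-first input of shape (N, C, D, H, W).
    """

    def __init__(self, cropping=((1, 1), (1, 1), (1, 1))):
        super().__init__()
        # cropping[i] = (voxels removed at the start, voxels removed at the end)
        self.cropping = cropping

    def forward(self, x):
        (d0, d1), (h0, h1), (w0, w1) = self.cropping
        d, h, w = x.shape[-3:]
        # Plain slicing on the last three axes; autograd handles backward.
        # Using `d - d1` (not `-d1`) keeps a zero end-crop working correctly.
        return x[..., d0:d - d1, h0:h - h1, w0:w - w1]


x = torch.randn(2, 4, 10, 12, 14)
crop = Cropping3d(((1, 1), (2, 2), (3, 3)))
print(crop(x).shape)  # torch.Size([2, 4, 8, 8, 8])
```

No custom autograd code is needed: because the module only slices, PyTorch differentiates through it like any other tensor operation.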
In PyTorch you don't necessarily need to write layers for everything; often you can do what you want directly in the forward pass. The basic rules to keep in mind when operating on torch tensors for which you will need to compute gradients are:
1. Don't convert tensors to other types for the computation (e.g. use torch.sum instead of converting to numpy and using numpy.sum).
2. Avoid in-place operations on tensors that autograd may still need (e.g. write x = x + ... instead of x += ...).
That said, you can simply use slicing. It might look something like this:
def forward(self, x):
    ...
    x = self.conv3(x)
    # crop out part of the feature map: x has shape (N, C, H, W), so this
    # keeps rows 5..19 and columns 5..19 of every channel
    x = x[:, :, 5:20, 5:20]
    x = self.relu3(x)
    ...
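The slicing above works on a 2D feature map; the same idea extends to 3D volumes by slicing one more axis. The sketch below is a hypothetical `CenterCrop3d` module (not part of PyTorch or torchvision) that center-crops the last three dimensions of an (N, C, D, H, W) tensor to a target size, and then checks that gradients flow back through the crop. It assumes the target size is no larger than the input; when the size difference along an axis is odd, this sketch leaves the extra voxel at the end, which may not match torchvision's own rounding for 2D crops.

```python
import torch
import torch.nn as nn


class CenterCrop3d(nn.Module):
    """Hypothetical center crop for (N, C, D, H, W) tensors, using slicing only."""

    def __init__(self, size):
        super().__init__()
        self.size = tuple(size)  # target (D, H, W)

    def forward(self, x):
        slices = []
        for dim, target in zip(x.shape[-3:], self.size):
            if target > dim:
                raise ValueError(f"crop size {target} exceeds input size {dim}")
            start = (dim - target) // 2  # floor: extra voxel stays at the end
            slices.append(slice(start, start + target))
        # Ellipsis keeps the batch and channel dimensions untouched.
        return x[(..., *slices)]


crop = CenterCrop3d((4, 4, 4))
x = torch.randn(1, 2, 8, 9, 10, requires_grad=True)
y = crop(x)
print(y.shape)  # torch.Size([1, 2, 4, 4, 4])

# Gradients flow back only into the voxels that survived the crop.
y.sum().backward()
print(int(x.grad.sum()))  # 128 = 1 * 2 * 4 * 4 * 4
```

Because the module has no parameters, it can be dropped into an `nn.Sequential` or called inside `forward` exactly like the 2D slice above.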