PyTorch: how do I add an arbitrary function between layers?


!!! I am just starting to understand PyTorch !!!

Assume that the model has the following architecture:

(conv1): Conv2d(2, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=256, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)

What should I do to add some MyFunction between conv1 and pool layers, for example?

Here is my current code:

from torch.nn import Conv2d, Linear, MaxPool2d, Module, ReLU, Sequential

class CNN(Module):
    def __init__(self) -> None:
        super(CNN, self).__init__()
        self.cnn_layer = Sequential(
            Conv2d(in_channels=2, out_channels=6, kernel_size=5),
            # MyFunction here
            ReLU(inplace=True),
            MaxPool2d(kernel_size=2, stride=2),
        )
        self.linear_layers = Sequential(
            Linear(256, 120), Linear(120, 84), Linear(84, 10)
        )

    def forward(self, image):
        image = self.cnn_layer(image)

        image = image.view(-1, 4 * 4 * 16)
        image = self.linear_layers(image)
        return image

Solution

  • Note that a Sequential layer is just a way to bundle multiple feed-forward layers into "one". This means you don't need to pass your data to each layer explicitly (in contrast to what I did below). I rewrote your example without Sequential layers so that you can see what happens underneath. Doing so makes it easy to access the layer outputs and inputs and change them according to your needs. Of course, you could also rearrange your Sequential bundles to make a split where you need to access x for your "function".

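    from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU

    class CNN(Module):
        def __init__(self) -> None:
            super(CNN, self).__init__()
            self.conv1 = Conv2d(in_channels=2, out_channels=6, kernel_size=5)
            self.relu1 = ReLU(inplace=True)
            self.maxpool1 = MaxPool2d(kernel_size=2, stride=2)
            # conv2 is taken from the architecture listed in the question; without it
            # (and a second pooling step) the flattened size would not be 256.
            self.conv2 = Conv2d(in_channels=6, out_channels=16, kernel_size=5)
            self.flatten = Flatten()
            self.linear1 = Linear(256, 120)
            self.linear2 = Linear(120, 84)
            self.linear3 = Linear(84, 10)

        def forward(self, image):
            # Assumes 28x28 inputs, so the flattened size is 16 * 4 * 4 = 256.
            x = self.conv1(image)
            x = x * 2 - 123  # arbitrary stuff
            x = self.relu1(x)
            x = self.maxpool1(x)
            # second conv block, reusing the stateless ReLU and pooling modules
            x = self.maxpool1(self.relu1(self.conv2(x)))
            x = self.flatten(x)  # shorter than your reshaping
            x = self.linear1(x)
            x = self.linear2(x)
            x = self.linear3(x)
            return x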
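
If you would rather keep the Sequential bundles from the question, another option is to wrap the function in a small Module of your own and drop it into the Sequential like any other layer. The sketch below is only one way to do it. `Lambda` and `my_function` are hypothetical names, not part of PyTorch. The second conv block and the 28x28 input size are assumptions, chosen so that the flattened size matches `Linear(256, 120)`.

```python
import torch
from torch import nn


class Lambda(nn.Module):
    """Hypothetical helper: wraps a plain callable so it can sit inside nn.Sequential."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


def my_function(x):
    # Placeholder for MyFunction; same arbitrary arithmetic as in the answer above.
    return x * 2 - 123


class CNN(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.cnn_layer = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=6, kernel_size=5),
            Lambda(my_function),  # arbitrary function between conv1 and pool
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Assumed second conv block, following the architecture in the question.
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.linear_layers = nn.Sequential(
            nn.Flatten(),  # (N, 16, 4, 4) -> (N, 256) for 28x28 inputs
            nn.Linear(256, 120),
            nn.Linear(120, 84),
            nn.Linear(84, 10),
        )

    def forward(self, image):
        return self.linear_layers(self.cnn_layer(image))


model = CNN()
batch = torch.randn(8, 2, 28, 28)  # assumed input: 8 images, 2 channels, 28x28
output = model(batch)
```

A wrapper like this is enough when the function has no parameters to learn. If it does have learnable parameters, it is probably cleaner to write it as a full Module subclass so that they are registered with the model.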