!!! I am just starting to understand PyTorch !!!
Assume that the model has the following architecture:
(conv1): Conv2d(2, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=256, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
What should I do to add a custom function, say MyFunction, between the conv1 and pool layers?
Here is my current code:
class CNN(Module):
    def __init__(self) -> None:
        super(CNN, self).__init__()
        self.cnn_layer = Sequential(
            Conv2d(in_channels=2, out_channels=6, kernel_size=5),
            # MyFunction here
            ReLU(inplace=True),
            MaxPool2d(kernel_size=2, stride=2),
        )
        self.linear_layers = Sequential(
            Linear(256, 120), Linear(120, 84), Linear(84, 10)
        )

    def forward(self, image):
        image = self.cnn_layer(image)
        image = image.view(-1, 4 * 4 * 16)
        image = self.linear_layers(image)
        return image
Note that a Sequential layer is just a way to bundle multiple feed-forward layers into one. This means you don't need to pass your data to each layer explicitly (in contrast to what I did below). I rewrote your example without Sequential layers so that you can see what happens underneath. Doing so makes it easy to access each layer's inputs and outputs and change them according to your needs. Of course, you could also rearrange your Sequential bundles so that they are split at the point where you need access to x for your "function".
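from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU


class CNN(Module):
    def __init__(self) -> None:
        super(CNN, self).__init__()
        self.conv1 = Conv2d(in_channels=2, out_channels=6, kernel_size=5)
        self.relu1 = ReLU(inplace=True)
        self.maxpool1 = MaxPool2d(kernel_size=2, stride=2)
        self.flatten = Flatten()
        # in_features=256 (16 * 4 * 4) comes from your printout and assumes the
        # conv2 + pool stage, which your snippet omits. With conv1 only, set it
        # to the actual flattened size of the pooled conv1 output.
        self.linear1 = Linear(256, 120)
        self.linear2 = Linear(120, 84)
        self.linear3 = Linear(84, 10)

    def forward(self, image):
        x = self.conv1(image)
        x = x * 2 - 123  # arbitrary stuff: put your own function here
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.flatten(x)  # shorter than your reshaping
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x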
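If you would rather keep your Sequential bundles, one option is to wrap the function in a small Module of its own, since Sequential only accepts modules. Below is a minimal sketch of that idea, not a definitive implementation. The body of MyFunction is just the same arbitrary arithmetic as in the forward method above. The conv2 stage, its ReLU, and the 28x28 input size are my assumptions, added so that the flattened size matches the 256 input features in your printout.

```python
import torch
from torch import nn


class MyFunction(nn.Module):
    """Hypothetical stand-in for your own function, wrapped as a layer."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2 - 123  # arbitrary arithmetic, replace with your own


class CNN(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.cnn_layer = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=6, kernel_size=5),
            MyFunction(),  # runs right after conv1, before ReLU and pooling
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Assumed second stage, following the printed architecture:
            # 28x28 input -> 24 -> 12 -> 8 -> 4, so 16 * 4 * 4 = 256 features.
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.linear_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 120),
            nn.Linear(120, 84),
            nn.Linear(84, 10),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return self.linear_layers(self.cnn_layer(image))


model = CNN()
batch = torch.randn(4, 2, 28, 28)  # assumed input: 4 images, 2 channels, 28x28
output = model(batch)
print(output.shape)
```

This keeps forward short, and the function shows up as its own entry when you print the model. The explicit version above is still handier when the function needs more than one tensor or some state from elsewhere in forward.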