I am new to Neural Networks and I have the following N.N. in a python notebook below that takes in images as the input. I am trying to get it to run but I keep getting the following error: TypeError: linear(): argument 'input' (position 1) must be Tensor, not Flatten
. (Please let me know if any additional information is needed).
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(600, 120)
self.fc2 = nn.Linear(120, 2)
self.fc3 = nn.Linear(2, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = nn.Flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = Network()
I have looked at other sources where others have gotten a similar error but not the exact error that I got. At first I tried to see if making my neural network similar to how other had it but that did not seem to work.
You are using the wrong Flatten!
There are two options:
A "flatten layer": https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html#torch.nn.Flatten
The flatten function itself: https://pytorch.org/docs/stable/generated/torch.flatten.html
In your network you are using the first, but you should be using the second! You are creating a flatten layer instead of applying the flatten transformation to your inputs.
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(600, 120)
self.fc2 = nn.Linear(120, 2)
self.fc3 = nn.Linear(2, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
Or
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(600, 120)
self.fc2 = nn.Linear(120, 2)
self.fc3 = nn.Linear(2, 10)
self.flatten = nn.Flatten(1)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.flatten(x)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x