Tags: deep-learning, pytorch, conv-neural-network, torchvision

Problem with a nested network in PyTorch: "TypeError: forward() missing 1 required positional argument: 'x'"


I am trying to build an architecture consisting of one convolutional layer followed by a block of three convolutional layers. I first build the inner block as "MysmallNet(nn.Module)", and then I build "MybigNet", which calls the small network. This is my code:

import torch.nn as nn
class MysmallNet(nn.Module):
    def __init__(self):
        super(MysmallNet, self).__init__()
        # TODO Task 3: Design Your Network
        self.Convlayer_1 = nn.Conv2d(in_channels = 16, out_channels = 16, kernel_size = 3, stride = 1,padding=1)
        self.Convlayer_2 = nn.Conv2d(in_channels=16,out_channels=16,kernel_size=3,stride=1, padding=1)
        self.Convlayer_3 = nn.Conv2d(in_channels=16,out_channels=16,kernel_size=3,stride=1, padding=1)
        
    def forward(self, x):
        # TODO Task 3: Design Your Network
        residual1 = x
        x = self.Convlayer_1(x)
        x = self.Convlayer_2(x)
        x = self.Convlayer_3(x)
        return x

MysmallNetV2= MysmallNet()

class MybigNet(nn.Module):
    def __init__(self):
        super(MybigNet, self).__init__()

        self.Convlayer_1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3,stride=1,padding=1)
        self.smallNet= MysmallNetV2()

    def forward(self, x):
        x = self.Convlayer_1(x)
        x = self.smallNet(x)
        return x

modelBig = MybigNet()

I get the error when I create my model as "modelBig" (the line modelBig = MybigNet()). The displayed error is:

TypeError: forward() missing 1 required positional argument: 'x'
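The error can be reproduced in isolation. In this minimal sketch, TinyNet is a hypothetical one-layer stand-in for MysmallNet: calling an already-built module instance with no arguments fails the same way, because nn.Module.__call__ passes its arguments straight through to forward(). Depending on the Python version, the message may be prefixed with the class name (e.g. "TinyNet.forward()").

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical one-layer stand-in for MysmallNet."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(16, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


tiny = TinyNet()  # an instance, like MysmallNetV2 in the question

message = None
try:
    # nn.Module.__call__ hands its arguments to forward(); none are given
    # here, so forward() is invoked without x and Python raises TypeError.
    tiny()
except TypeError as err:
    message = str(err)

print(message)
```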

Solution

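  • Your definition of MybigNet is wrong. MysmallNetV2 is already an instance of MysmallNet, so MysmallNetV2() calls that module with no input, which runs its forward() without x and raises the TypeError. Create a new instance of the class instead: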

    class MybigNet(nn.Module):
        def __init__(self):
            super(MybigNet, self).__init__()
    
            self.Convlayer_1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3,stride=1,padding=1)
            self.smallNet = MysmallNet()  # instantiate the class; don't call an existing instance
    
        def forward(self, x):
            x = self.Convlayer_1(x)
            x = self.smallNet(x)
            return x
    

    This should solve the issue.
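A quick way to confirm the fix is to build the corrected model and run a dummy batch through it. The sketch below assumes 3-channel 32x32 inputs (the input size is not given in the question). It also adds a hypothetical optional small_net argument, not part of the answer above, for the case where you really do want MybigNet to reuse an existing MysmallNet instance such as MysmallNetV2: pass the module itself and store it, without calling it.

```python
import torch
import torch.nn as nn


class MysmallNet(nn.Module):
    """Inner block: three 3x3 convolutions, 16 -> 16 channels, same spatial size."""

    def __init__(self):
        super().__init__()
        self.Convlayer_1 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.Convlayer_2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.Convlayer_3 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.Convlayer_1(x)
        x = self.Convlayer_2(x)
        x = self.Convlayer_3(x)
        return x


class MybigNet(nn.Module):
    """Outer network: one 3 -> 16 convolution followed by a MysmallNet block.

    `small_net` is a hypothetical optional argument (not in the original
    answer) that lets the caller supply an existing MysmallNet instance.
    """

    def __init__(self, small_net=None):
        super().__init__()
        self.Convlayer_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        # Store the module object itself; calling it here would run forward().
        self.smallNet = small_net if small_net is not None else MysmallNet()

    def forward(self, x):
        x = self.Convlayer_1(x)
        x = self.smallNet(x)
        return x


modelBig = MybigNet()                 # builds its own inner block
shared_small = MysmallNet()
modelShared = MybigNet(shared_small)  # reuses an existing instance

dummy = torch.randn(1, 3, 32, 32)     # assumed input: one 3-channel 32x32 image
out = modelBig(dummy)
print(out.shape)  # torch.Size([1, 16, 32, 32])
```

With kernel_size=3, stride=1 and padding=1, every convolution preserves the 32x32 spatial size, so only the channel count changes (3 to 16).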