Where is this super function looking for __init__()?


I'm looking at this Python class and I'm trying to figure out why the call to super has any arguments. If my understanding is correct, super() with no arguments would do the same job in this case. Am I correct?

Here's the code:

import torch

class Net(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Net, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()
    
    def forward(self, x):
        hidden = self.fc1(x)
        relu = self.relu(hidden)
        output = self.fc2(relu)
        output = self.sigmoid(output)
        return output
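The title asks where super(Net, self).__init__() looks. A minimal sketch, assuming a hypothetical stand-in Module class in place of torch.nn.Module (so it runs without PyTorch) and a hypothetical Logged mixin added only to lengthen the MRO: super(Net, self) walks type(self).__mro__ and starts the search at the class right after Net, so which __init__ it finds depends on the instance's MRO, not only on Net's direct base.

```python
# Records which __init__ methods run, in order.
calls = []


class Module:
    """Hypothetical stand-in for torch.nn.Module."""

    def __init__(self):
        calls.append("Module.__init__")


class Net(Module):
    def __init__(self, input_size, hidden_size):
        calls.append("Net.__init__")
        # Look up __init__ in type(self).__mro__, starting just after Net.
        super(Net, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size


class Logged(Module):
    """Hypothetical mixin, only here to make the MRO longer."""

    def __init__(self):
        calls.append("Logged.__init__")
        super().__init__()


class LoggedNet(Net, Logged):
    pass


print([cls.__name__ for cls in Net.__mro__])        # ['Net', 'Module', 'object']
Net(4, 8)
print(calls)                                         # ['Net.__init__', 'Module.__init__']

calls.clear()
print([cls.__name__ for cls in LoggedNet.__mro__])  # ['LoggedNet', 'Net', 'Logged', 'Module', 'object']
LoggedNet(4, 8)
print(calls)                                         # ['Net.__init__', 'Logged.__init__', 'Module.__init__']
```

For a plain Net the class after Net is Module, so the call runs Module.__init__. In the real code that is torch.nn.Module.__init__, which PyTorch expects to have run before submodules such as self.fc1 are assigned.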

Solution

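  • In Python 2, the arguments were not optional. In Python 3, there is special compiler magic to determine which arguments to use if they are omitted: when a method defined in a class body refers to super (or __class__), the compiler gives it an implicit __class__ cell holding the class being defined, and zero-argument super() combines that cell with the method's first argument. In the context of Net.__init__, super() is therefore equivalent to super(Net, self).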
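A small sketch of that equivalence, again using a hypothetical stand-in Module instead of torch.nn.Module (NetExplicit, NetImplicit, SubNet and outside_class are illustrative names, not from the question). It checks that both spellings reach the same __init__, shows the implicit __class__ cell the compiler adds to a method that uses super(), and shows that the cell holds the class where the method was defined rather than type(self).

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module."""

    def __init__(self):
        self.ready = True


class NetExplicit(Module):
    def __init__(self):
        # Python 2 style: class and instance passed explicitly.
        super(NetExplicit, self).__init__()


class NetImplicit(Module):
    def __init__(self):
        # Python 3 style: the class comes from the __class__ cell, the instance from self.
        super().__init__()

    def super_details(self):
        # Referencing super (or __class__) gives this method an implicit __class__ cell.
        s = super()
        return __class__, s.__thisclass__, s.__self__


class SubNet(NetImplicit):
    pass


print(NetExplicit().ready, NetImplicit().ready)        # True True
print(NetImplicit.__init__.__code__.co_freevars)       # ('__class__',)

sub = SubNet()
cls, this_class, bound_self = sub.super_details()
# The cell holds NetImplicit (where the method was defined), not SubNet.
print(cls.__name__, this_class.__name__, bound_self is sub)  # NetImplicit NetImplicit True


def outside_class(self):
    # Defined outside any class body, so the compiler creates no __class__ cell.
    return super()


try:
    outside_class(sub)
except RuntimeError as exc:
    print("RuntimeError:", exc)
```

The SubNet case is why super(type(self), self) is not a safe Python 2 spelling: in a subclass it would start the lookup after SubNet, land back in NetImplicit.__init__, and recurse.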