Tags: python, oop, deep-learning, pytorch

Reason for wrapping simple functions inside of classes (PyTorch)


What is the reason for wrapping simple functions such as torch.cat() (or layers like MaxPool2d) inside a class like this:

class Concat(nn.Module):
    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)

class MP(nn.Module):
    def __init__(self, k=2):
        super(MP, self).__init__()
        self.m = nn.MaxPool2d(kernel_size=k, stride=k)

    def forward(self, x):
        return self.m(x)

Solution

  • The biggest reason is that these wrapped operations get registered as submodules of the model (the model holds a reference to them). On top of that, PyTorch users (like me :)) heavily rely on hooks to interact with models, so it is useful to be able to attach hooks to these operations when needed (for debugging, for modifying the model's behaviour without changing its source code, etc.). A sketch illustrating both points follows below.
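Here is a minimal sketch of what that buys you. The Net model and the log_shape hook below are hypothetical names used only for illustration; the MP class is the one from the question. Because MP is an nn.Module, it shows up in the module tree and can receive a forward hook without any change to Net's code:

import torch
import torch.nn as nn

class MP(nn.Module):  # same wrapper as in the question
    def __init__(self, k=2):
        super(MP, self).__init__()
        self.m = nn.MaxPool2d(kernel_size=k, stride=k)

    def forward(self, x):
        return self.m(x)

class Net(nn.Module):  # hypothetical model using the wrapper
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = MP(k=2)  # registered as a submodule of Net

    def forward(self, x):
        return self.pool(self.conv(x))

net = Net()

# The wrapper appears in the module tree ('pool' and 'pool.m'),
# unlike a bare torch.nn.functional call inside forward().
print(list(dict(net.named_modules()).keys()))

# A forward hook can be attached to it to observe (or modify)
# its behaviour without editing Net's source code.
def log_shape(module, inputs, output):
    print(f"{module.__class__.__name__}: {inputs[0].shape} -> {output.shape}")

handle = net.pool.register_forward_hook(log_shape)

x = torch.randn(1, 3, 32, 32)
_ = net(x)       # the hook prints the pooling layer's input/output shapes
handle.remove()  # detach the hook when done

The same reasoning applies to the Concat wrapper: once the operation is a registered submodule, it is visible in print(net) and net.named_modules(), which makes it much easier to locate and instrument than an anonymous torch.cat() call buried in a forward() method.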