Tags: python, pytorch, abstract-class

Should I inherit from both nn.Module and ABC?


Should I create a class that inherits from both torch.nn.Module and ABC? And is it acceptable to call the __init__() method of ABC? (I guess it's OK, since the ABC class is just a trivial subclass of object.)

Or should I raise NotImplementedError in the base method instead? If both approaches are valid, how do I decide which one to use?

I use AbstractModel to initialize the config for all child modules.

import torch.nn as nn
from abc import ABC, abstractmethod

class AbstractModel(nn.Module, ABC):
    def __init__(self, config):
        super().__init__()
        self.config = config
    
    @abstractmethod
    def generate(self):
        pass
    
class sub(AbstractModel):
    def __init__(self, config):
        super().__init__(config)

    def generate(self):
        print(self.config)

Solution

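  • In addition to @Nopileos's answer: use abc and NotImplementedError together. abc prevents the abstract class itself from being instantiated, and raising NotImplementedError in the body of the abstract method prevents a subclass from calling that method through super() and silently relying on it.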

    One example could be:

    import abc
    
    import torch
    
    
    class Base(torch.nn.Module, abc.ABC):
        def __init__(self, out_features: int):
            super().__init__()
            self.out_features = out_features
            self.module = torch.nn.Sequential(
                torch.nn.LazyLinear(out_features * 2),
                torch.nn.GELU(),
                torch.nn.LazyLinear(out_features),
            )
    
        @abc.abstractmethod
        def forward(self, x: torch.Tensor):
            raise NotImplementedError
    
    
    class ResnetLike(Base):
        def forward(self, x: torch.Tensor):
            # super().forward(x) would raise NotImplementedError, as intended.
            # The residual sum requires x's last dimension to equal out_features.
            return self.module(x) + x
    
    
    class TestNetwork(Base):
        def forward(self, x: torch.Tensor):
            return self.module(x) * 2
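
The two safeguards described above can be checked with a short, self-contained sketch. It is an illustration under assumptions rather than part of the original answer: it swaps LazyLinear for a plain torch.nn.Linear with a fixed size to keep the example small, and the Forgetful subclass is hypothetical, added only to show what happens when a subclass delegates to the abstract forward().

```python
import abc

import torch


class Base(torch.nn.Module, abc.ABC):
    """Abstract base: shared layers live here, forward() is left to subclasses."""

    def __init__(self, features: int):
        super().__init__()
        # Assumption: a fixed-size Linear layer stands in for the LazyLinear stack.
        self.module = torch.nn.Linear(features, features)

    @abc.abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ResnetLike(Base):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x) + x


class Forgetful(Base):
    """Hypothetical subclass that delegates to the abstract forward()."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x)


# 1. abc blocks instantiation of the abstract class itself.
try:
    Base(4)
    instantiation_blocked = False
except TypeError:
    instantiation_blocked = True

# 2. A concrete subclass behaves like any other module.
out = ResnetLike(4)(torch.zeros(2, 4))

# 3. Reaching the abstract body through super() raises NotImplementedError.
try:
    Forgetful(4)(torch.zeros(2, 4))
    super_call_blocked = False
except NotImplementedError:
    super_call_blocked = True
```

The first check fails at construction time, while the third only fails once forward() is actually called. That difference is the practical reason to combine the two: abc catches a missing override early, and NotImplementedError covers the case where an override exists but still falls back on the base method.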