Should I create a class that inherits from both `torch.nn.Module` and `ABC`? And is it acceptable to call the `__init__()` function of `ABC`? (I guess it's OK, since `ABC` is just a trivial subclass of `object`.)
If I should instead use the `NotImplementedError` approach (raising it from the base-class method), how do I decide when to use which approach?
I use `AbstractModel` to initialize the config for all of its child (subclass) modules.
import torch.nn as nn
from abc import ABC, abstractmethod
class AbstractModel(nn.Module, ABC):
    """Abstract ``nn.Module`` base that keeps a shared ``config`` object.

    Concrete subclasses must override :meth:`generate`; because the class
    mixes in ``ABC``, instantiating a subclass that leaves ``generate``
    abstract raises ``TypeError``.
    """

    def __init__(self, config):
        # nn.Module must be initialised before attributes are assigned on it.
        super().__init__()
        self.config = config

    @abstractmethod
    def generate(self):
        """Produce the model's output; subclasses must override this.

        The base implementation does nothing and returns ``None``.
        """
class sub(AbstractModel):
    """Concrete ``AbstractModel`` whose ``generate`` prints the stored config."""

    def __init__(self, config):
        # Delegate to AbstractModel, which initialises nn.Module and stores
        # ``config`` on the instance.
        super().__init__(config)

    def generate(self):
        """Print ``self.config``; fulfils the abstract ``generate`` contract."""
        print(self.config)
In addition to @Nopileos's answer: you should use `NotImplementedError` together with `abc`. Raising it in the abstract method prevents an inheriting class from calling `super()` and silently using that method (a plain `pass` body would just return `None`).
One example could be:
import abc
import torch
class Base(torch.nn.Module, abc.ABC):
    """Abstract module that owns a small two-layer MLP in ``self.module``.

    Subclasses must implement :meth:`forward`. Both linear layers are
    ``LazyLinear``, so the input feature size is inferred on the first call.

    Args:
        out_features: Output width of the MLP; the hidden layer is twice this.
    """

    def __init__(self, out_features: int):
        super().__init__()
        # Store the constructor argument under its own name (the previous
        # code read an undefined ``out_channels`` and raised NameError).
        self.out_features = out_features
        self.module = torch.nn.Sequential(
            torch.nn.LazyLinear(out_features * 2),
            torch.nn.GELU(),
            torch.nn.LazyLinear(out_features),
        )

    @abc.abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Raise instead of ``pass`` so that an accidental ``super().forward(x)``
        # in a subclass fails loudly rather than silently returning None.
        raise NotImplementedError
class ResnetLike(Base):
    """Residual variant: adds the input back onto the MLP output."""

    def forward(self, x: torch.Tensor):
        # Calling super().forward(x) here would correctly raise, because
        # Base.forward is abstract and raises NotImplementedError.
        transformed = self.module(x)
        # NOTE(review): the skip connection needs ``transformed`` and ``x`` to
        # have compatible shapes — presumably x's last dim equals
        # out_features; confirm with callers.
        return transformed + x
class TestNetwork(Base):
    """Variant that doubles the MLP output."""

    def forward(self, x: torch.Tensor):
        out = self.module(x)
        return out * 2