I'm working on replicating the SHAP package algorithm - an explainability algorithm for machine learning. I've been reading through the author's code, and I've come across a pattern I've never seen before.
The author has created a superclass called Explainer, which is a common interface for all the different model-specific implementations of the algorithm. The Explainer's __init__ method accepts a string for the algorithm type and switches itself to the corresponding subclass if called directly. It does this using multiple versions of the following pattern:
if algorithm == "exact":
    self.__class__ = explainers.Exact
    explainers.Exact.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
I understand that this code changes the instance's class from the superclass to one of its subclasses and initialises it by passing itself to that subclass's __init__. But why would you do this?
This is a non-standard and awkward way of implementing the Abstract Factory design pattern. The idea is that, although the base class contains state and functionality that are useful for implementing derived classes, it should not be instantiated directly. The full code contains logic that checks whether the base class __init__ is being called "directly" or via super; in the former case, it checks a parameter and chooses an appropriate derived class. (That derived class, of course, will end up calling back to this __init__, but this time super is used, so there is no unbounded recursion.)
To clarify, although this is not standard, it does work:
class Base:
    def __init__(self, *, value=None, kind=None):
        if self.__class__ is Base:
            if kind == 'derived':
                self.__class__ = Derived
                Derived.__init__(self, value)
            else:
                raise ValueError("invalid 'kind'; cannot create Base instances explicitly")

class Derived(Base):
    def __init__(self, value):
        super().__init__()
        self.value = value
    def method(self):
        return 'derived method not defined in base'
Testing it:
>>> Base()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 8, in __init__
ValueError: invalid 'kind'; cannot create Base instances explicitly
>>> Base(value=1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 8, in __init__
ValueError: invalid 'kind'; cannot create Base instances explicitly
>>> Base(value=1, kind='derived')
<__main__.Derived object at 0x7f94fe025790>
>>> Base(value=1, kind='derived').method()
'derived method not defined in base'
>>> Base(value=1, kind='derived').value
1
>>> Derived(2)
<__main__.Derived object at 0x7f94fcc2aa00>
>>> Derived(2).method()
'derived method not defined in base'
>>> Derived(2).value
2
Setting the __class__ attribute allows the factory-created Derived instance to access the derived method, and calling __init__ causes it to have a per-instance value attribute. The order matters here: Derived.__init__ calls super().__init__(), and zero-argument super() raises a TypeError unless self is already an instance of Derived, so __class__ must be reassigned before the explicit call. Alternatively, it would work (although it would look strange) to call self.__init__(value), but again only after changing the __class__.
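As a rough sketch of the self.__init__(value) variant mentioned above (the same Base and Derived as in the example, with only the explicit Derived.__init__ call swapped out), the following shows that ordinary method lookup picks up the derived initialiser once __class__ has been reassigned:

```python
class Base:
    def __init__(self, *, value=None, kind=None):
        if self.__class__ is Base:
            if kind == 'derived':
                # Reassign first: until this line runs, self.__init__
                # would still resolve to Base.__init__.
                self.__class__ = Derived
                # Ordinary method lookup now finds Derived.__init__.
                self.__init__(value)
            else:
                raise ValueError("invalid 'kind'; cannot create Base instances explicitly")


class Derived(Base):
    def __init__(self, value):
        super().__init__()  # self is a Derived here, so Base.__init__ does nothing
        self.value = value

    def method(self):
        return 'derived method not defined in base'


obj = Base(value=1, kind='derived')
print(type(obj).__name__)        # Derived
print(isinstance(obj, Derived))  # True
print(obj.value, obj.method())   # 1 derived method not defined in base
```

Reassigning __class__ like this only works when the two classes have compatible instance layouts, which is the case for plain Python classes such as these.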
A more Pythonic way to implement this is to use the standard library abc functionality to mark the base class as "abstract", and use a named method as a factory. For example, provided the base class derives from ABC, decorating its __init__ with @abstractmethod will prevent it from being instantiated directly, while forcing derived classes to implement __init__. When they do, they will call super().__init__, which will work without error. For the factory, we can use a method decorated with @staticmethod in the base class (or just an ordinary function; but using @staticmethod effectively "namespaces" the factory). It can, for example, use a string name to choose a derived class, and instantiate it.
A minimal example:
from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def __init__(self):
        pass
    @staticmethod
    def create(kind):
        # TODO: add more derived classes to the mapping
        return {'derived': Derived}[kind]()

class Derived(Base):
    def __init__(self):
        super().__init__()

# TODO: implement additional derived classes
Testing it:
>>> Base()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Can't instantiate abstract class Base with abstract methods __init__
>>> Derived()
<__main__.Derived object at 0x7f94fe025310>
>>> Base.create('derived')
<__main__.Derived object at 0x7f94fe025910>
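If the hand-maintained mapping in create becomes a nuisance, one possible extension (not something the SHAP code does, as far as the snippet in the question shows) is to let subclasses register themselves through __init_subclass__. In this sketch the _registry dictionary, the kind class keyword, and the value parameter are all hypothetical names chosen for illustration:

```python
from abc import ABC, abstractmethod


class Base(ABC):
    # Hypothetical registry: maps a string name to a concrete subclass.
    _registry = {}

    def __init_subclass__(cls, kind=None, **kwargs):
        super().__init_subclass__(**kwargs)
        # Only subclasses that pass a `kind` keyword are registered.
        if kind is not None:
            Base._registry[kind] = cls

    @abstractmethod
    def __init__(self, value=None):
        # Shared state lives here; subclasses reach it via super().__init__.
        self.value = value

    @staticmethod
    def create(kind, *args, **kwargs):
        try:
            cls = Base._registry[kind]
        except KeyError:
            raise ValueError(f"unknown kind: {kind!r}") from None
        return cls(*args, **kwargs)


class Derived(Base, kind='derived'):
    def __init__(self, value=None):
        super().__init__(value)

    def method(self):
        return 'derived method not defined in base'


obj = Base.create('derived', 3)
print(type(obj).__name__, obj.value)  # Derived 3
```

Base itself still cannot be instantiated, because its __init__ remains abstract, and adding another derived class no longer requires touching create.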