
How can I make a class object callable in Python?


In Python 3, how can I define a class MyClass so that I can instantiate it with obj = MyClass(param1, param2) and then call the instance itself to compute an operation, as in res = obj(in1, in2, in3)?

For instance, with PyTorch you can declare a model as mod = MyResNet50() and then compute its prediction as pred = mod(input).

Below is the code I tried. It defines a regular method and calls it as obj.method(), rather than calling the instance directly.

import numpy as np


class MLP:

    def __init__(self, hidden_units: int, input_size: int):
        self.hidden_units = hidden_units
        self.input_size = input_size
        self.layer1 = np.random.normal(0, 0.01, size=(hidden_units, input_size))
        self.layer2 = np.random.normal(0, 0.01, size=(1, hidden_units))

    def sigmoid(self, z):
        return 1/(1 + np.exp(-z))

    def predict(self, input):
        pred = self.layer1.dot(input)
        pred = self.layer2.dot(pred)
        return self.sigmoid(pred)

my_MLP = MLP(5, 10)
pred = my_MLP.predict(np.random.normal(0, 0.01, 10))

Solution

  • Implement __call__ so that instances of the class can be called with ():

    class MyClass:
        def __init__(self, p1, p2):
            self.p1 = p1
            self.p2 = p2
    
        def __call__(self, in1, in2, in3):
            return f'Hello, this is {self.p1} {self.p2}. You rang? {in1} {in2} {in3}'
    

    Example:

    >>> obj = MyClass('Foo', 'Bar')
    >>> res = obj(1, 2, 3)
    >>> res
    'Hello, this is Foo Bar. You rang? 1 2 3'
    

    If the class does not define __call__ (either directly or through a base class), calling an instance raises a TypeError:

    class MyClass:
        def __init__(self, p1, p2):
            self.p1 = p1
            self.p2 = p2
    
        # no other methods and only descendant of `object`...
    
    >>> MyClass('Foo', 'Bar')()
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    TypeError: 'MyClass' object is not callable
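
Applied to the MLP class from the question, the same idea might look like the sketch below. It assumes the forward pass stays exactly as in the original predict() method; the predict alias and the staticmethod decorator on sigmoid are illustrative choices, not something the question requires.

```python
import numpy as np


class MLP:
    """Two-layer perceptron whose instances can be called like functions."""

    def __init__(self, hidden_units: int, input_size: int):
        self.hidden_units = hidden_units
        self.input_size = input_size
        self.layer1 = np.random.normal(0, 0.01, size=(hidden_units, input_size))
        self.layer2 = np.random.normal(0, 0.01, size=(1, hidden_units))

    @staticmethod
    def sigmoid(z):
        return 1 / (1 + np.exp(-z))

    def __call__(self, x):
        # Same computation as the question's predict() method.
        hidden = self.layer1.dot(x)
        out = self.layer2.dot(hidden)
        return self.sigmoid(out)

    # Optional alias so existing obj.predict(...) calls keep working.
    predict = __call__


my_mlp = MLP(5, 10)
x = np.random.normal(0, 0.01, 10)

pred = my_mlp(x)          # Python looks up __call__ on type(my_mlp)
print(callable(my_mlp))   # True
print(pred.shape)         # (1,)
```

The built-in callable() is a quick way to check whether an object supports the () syntax before calling it.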