
How do instances of Andrej Karpathy's BigramLanguageModel run as functions with no `__call__` function?


Andrej Karpathy's nanoGPT defines BigramLanguageModel as follows.

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        ...

It then runs the following.

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)

The call to m() runs the forward() method as if there were a __call__ function that called forward(). But no __call__ function is visible. How does this work?

Thanks.


Solution

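  • It's because the base class torch.nn.Module defines __call__, and that method in turn invokes forward(). BigramLanguageModel inherits __call__ from nn.Module, so it does not need to define its own. See the implementation of torch.nn.Module in the PyTorch source.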

    The following example shows a similar technique.

    class X:
        def test(self):
            print('test')
        __call__ = test
    
    X()()
    # This will output 'test'.
    

    As a side note, the __init__() of BigramLanguageModel was not copied correctly from that notebook: the call BigramLanguageModel(vocab_size) passes an argument, so __init__() must accept vocab_size.
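
To make the inheritance part concrete, here is a minimal sketch of the same pattern without PyTorch. The Module and Doubler classes are hypothetical stand-ins written for illustration. They are not PyTorch's actual code, and the real nn.Module.__call__ does more work (for example, running registered hooks) before and after calling forward().

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module (illustration only)."""

    def __call__(self, *args, **kwargs):
        # Python evaluates m(...) as type(m).__call__(m, ...).
        # Subclasses inherit this method, so they never define it themselves.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # Subclasses are expected to override forward().
        raise NotImplementedError


class Doubler(Module):
    # No __call__ here: only forward() is defined, as in BigramLanguageModel.
    def forward(self, x):
        return 2 * x


m = Doubler()
result = m(21)   # resolves to Module.__call__, which calls Doubler.forward
print(result)    # 42
```

Calling m(21) works even though Doubler defines only forward(), because attribute lookup on the class finds __call__ in the base class.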