Andrej Karpathy's nanoGPT defines BigramLanguageModel as follows.
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        ...
It then runs the following.
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
The call to m() runs the forward() method, as if there were a __call__ function that called forward(). But no __call__ function is visible. How does this work?
Thanks.
It's because the base class torch.nn.Module defines __call__, and BigramLanguageModel inherits it. See the implementation.
The following example shows a similar technique.
class X:
    def test(self):
        print('test')

    # Alias the method under the special name, so calling an instance runs test().
    __call__ = test

X()()
# This will output 'test'.
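A version closer to the nn.Module situation, where the base class defines __call__ and the subclass only defines forward(), can be sketched as below. This is a simplified illustration, not PyTorch's actual code: the Module class here is a hypothetical stand-in, Doubler is a made-up subclass, and the real torch.nn.Module.__call__ does more than this (for example, it also runs registered hooks around forward()).

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module (not the real class)."""

    def __call__(self, *args, **kwargs):
        # m(...) looks up __call__ on the instance's type, so Python
        # finds this inherited method and forwards the arguments.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # Subclasses are expected to override this.
        raise NotImplementedError


class Doubler(Module):
    # Only forward() is defined here; __call__ is inherited from Module.
    def forward(self, x):
        return 2 * x


m = Doubler()
result = m(21)  # dispatches to Doubler.forward(21) via Module.__call__
print(result)
```

This is also why you normally call the model as m(xb, yb) rather than m.forward(xb, yb): going through __call__ lets the base class do its extra work around forward().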
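As a side note, you copied the __init__() of BigramLanguageModel from that notebook incorrectly: the signature shown in your question takes no arguments, yet the model is constructed as BigramLanguageModel(vocab_size).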