python, machine-learning, pytorch, nlp, transformer-model

How does an instance of PyTorch's `nn.Linear()` process a tuple of tensors?


In the Annotated Transformer's implementation of multi-head attention, three tensors (query, key, value) are all passed to an nn.Linear(d_model, d_model):

# some class definition ...
self.linears = clones(nn.Linear(d_model, d_model), 4)  # list of 4 deep-copied nn.Linear modules wrapped in an nn.ModuleList
# more code ...
query, key, value = [
  lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
  for lin, x in zip(self.linears, (query, key, value))
]

My question: what happens at lin(x) when an instance of nn.Linear() is called on the tuple (query, key, value)? Is the tuple somehow concatenated into a single tensor? If so, along which dimension are the tensors concatenated?


Solution


Actually, nn.Linear does not process the input as a tuple of Q, K, V. In your code, the result is equivalent to this:

    out_Q = self.linears[0](Q)
    out_K = self.linears[1](K)
    out_V = self.linears[2](V)
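
If you did pass the tuple itself to a single layer, PyTorch would reject it, since nn.Linear expects a single tensor. A minimal sketch (the sizes are arbitrary, and the exact error message depends on the PyTorch version):

import torch
import torch.nn as nn

lin = nn.Linear(8, 8)
q, k, v = torch.randn(2, 5, 8), torch.randn(2, 5, 8), torch.randn(2, 5, 8)

try:
    lin((q, k, v))  # a tuple is not a Tensor
except TypeError as e:
    print(e)  # e.g. "linear(): argument 'input' (position 1) must be Tensor, not tuple"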
    

When you use zip(A, B), you get the pairs (A[0], B[0]), (A[1], B[1]), ... independently, stopping at the shorter iterable. So the comprehension pairs each of the first three linear layers with one of the three tensors and never concatenates anything; the fourth layer in self.linears is applied to the attention output later.
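A tiny illustration of the pairing, with strings standing in for the layers and tensors:

layers = ["lin0", "lin1", "lin2", "lin3"]
tensors = ("Q", "K", "V")
print(list(zip(layers, tensors)))
# [('lin0', 'Q'), ('lin1', 'K'), ('lin2', 'V')] -- the fourth layer is simply not paired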

Or, more specifically:

    query = self.linears[0](query)
    key = self.linears[1](key)
    value = self.linears[2](value)
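
To check the equivalence end to end, here is a self-contained sketch; clones follows the Annotated Transformer's definition, while d_model, h, and the batch and sequence sizes are arbitrary choices for the demo:

import copy
import torch
import torch.nn as nn

def clones(module, N):
    # N independent deep copies collected in an nn.ModuleList
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

d_model, h = 8, 2
d_k = d_model // h
nbatches, seq_len = 3, 5

linears = clones(nn.Linear(d_model, d_model), 4)
# nn.Linear acts on the last dimension, so (nbatches, seq_len, d_model) tensors work as-is
query, key, value = (torch.randn(nbatches, seq_len, d_model) for _ in range(3))

# The list comprehension from the question ...
q1, k1, v1 = [
    lin(x).view(nbatches, -1, h, d_k).transpose(1, 2)
    for lin, x in zip(linears, (query, key, value))
]

# ... matches three independent calls, one layer per tensor:
q2 = linears[0](query).view(nbatches, -1, h, d_k).transpose(1, 2)
k2 = linears[1](key).view(nbatches, -1, h, d_k).transpose(1, 2)
v2 = linears[2](value).view(nbatches, -1, h, d_k).transpose(1, 2)

print(torch.allclose(q1, q2), torch.allclose(k1, k2), torch.allclose(v1, v2))  # True True True
print(q1.shape)  # torch.Size([3, 2, 5, 4]) == (nbatches, h, seq_len, d_k)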