python · nlp · pytorch · spacy · spacy-transformers

How are token vectors calculated in spacy-pytorch-transformers


I am currently working with the spacy-pytorch-transformers package to experiment with its contextual embeddings.
When reading the introductory article (essentially the GitHub README), my understanding was that the token-level embeddings are the mean over the embeddings of all corresponding word pieces, i.e. embed(complex) would be the same as 1/2 * (embed(comp) + embed(#lex)).

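For illustration, here is a minimal sketch of what I understood that averaging to mean (the word pieces and the vector size are just placeholders for this example):

import numpy as np

# Hypothetical word pieces for "complex": "comp" and "#lex".
embed_comp = np.random.rand(768)  # placeholder for the "comp" piece vector
embed_lex = np.random.rand(768)   # placeholder for the "#lex" piece vector
# My expectation: the token vector is the arithmetic mean of its word pieces.
expected_embed_complex = (embed_comp + embed_lex) / 2
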
According to the BERT paper, this should simply use the last_hidden_state property of the network, but my MCVE below shows that this is not the case for spaCy 2.1.8 and spacy-pytorch-transformers 0.4.0, at least for BERT and RoBERTa (I have not verified this for other models):

import spacy
import numpy as np
nlp = spacy.load("en_pytt_robertabase_lg")  # either this or the BERT model
test = "This is a test"  # Note that all tokens are directly aligned, so no mean has to be calculated.
doc = nlp(test)
# doc[0].vector and doc.tensor[0] are equal, so either comparison gives the same result.
print(np.allclose(doc[0].vector, doc._.pytt_last_hidden_state[1, :]))
# returns False

The offset of 1 for the hidden states is due to the [CLS] token being the first input (it corresponds to the sentence classification task). I also checked every other available token of my sentence (which has no token alignment problems according to doc._.pytt_alignment), so there is no way I missed an off-by-one here.
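To double-check the alignment angle, I also compared every token against the mean of its aligned hidden-state rows, roughly like this (I am assuming here that doc._.pytt_alignment holds one list of word-piece row indices per token; adjust if its actual structure differs):

import spacy
import numpy as np

nlp = spacy.load("en_pytt_robertabase_lg")
doc = nlp("This is a test")
# Assumption: doc._.pytt_alignment[i] lists the word-piece rows aligned to token i.
for token in doc:
    rows = doc._.pytt_last_hidden_state[doc._.pytt_alignment[token.i]]
    print(token.text, np.allclose(token.vector, rows.mean(axis=0)))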

According to the source code, the corresponding hook simply returns the corresponding row of the tensor, so I do not see any transformation happening there. Is there something obvious that I am missing, or does this deviate from the expected behavior?
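For reference, my reading of that hook is essentially the following simplification (not the actual library code):

# Simplified reading of the vector hook: token.vector is just the
# matching row of doc.tensor, with no further transformation.
def vector_hook(token):
    return token.doc.tensor[token.i]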


Solution

  • It seems that there is a more elaborate weighting scheme behind this, which also accounts for the [CLS] and [SEP] token outputs in each sequence.

    This has also been confirmed by an issue post from the spaCy developers.

    Unfortunately, it seems that this part of the code has since moved as part of the package's renaming to spacy-transformers.
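
    Purely as an illustration of what such a scheme could look like (this is an assumption, not spacy-transformers' actual code), one could blend each token's word-piece mean with the [CLS]/[SEP] outputs along these lines:

    import numpy as np

    def weighted_token_vectors(last_hidden_state, alignment, special_idx, special_weight=0.1):
        # Illustrative only: mix each token's aligned word-piece mean with the
        # mean of the special-token ([CLS]/[SEP]) outputs. The weights are made up.
        special = last_hidden_state[special_idx].mean(axis=0)
        vectors = []
        for piece_idx in alignment:  # one list of word-piece row indices per token
            pieces = last_hidden_state[piece_idx].mean(axis=0)
            vectors.append((1 - special_weight) * pieces + special_weight * special)
        return np.stack(vectors)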