
Why is the timm vision transformer position embedding initialized to zeros?


I'm looking at the timm implementation of vision transformers, and the positional embedding is initialized with zeros:

self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

See here: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L309

How does this encode anything about position when it is later added to the patch embeddings?

x = x + self.pos_embed
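To make the shapes concrete, here is a small sketch assuming ViT-Base/16 settings (224x224 input, 16x16 patches, embed_dim=768; these numbers are not from the question). A random tensor stands in for the class token plus patch embeddings. It shows that pos_embed broadcasts over the batch, that it is a trainable parameter, and that at initialization the addition leaves x unchanged.

```python
import torch
import torch.nn as nn

# Assumed ViT-Base/16 settings, for illustration only
img_size, patch_size, embed_dim, batch = 224, 16, 768, 2

num_patches = (img_size // patch_size) ** 2  # 14 * 14 = 196 patches
# The "+ 1" is the extra row for the class token prepended to the patch tokens
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

# Random stand-in for the [class token, patch embeddings] sequence
x = torch.randn(batch, num_patches + 1, embed_dim)

# pos_embed has a leading dim of 1, so it broadcasts across the batch:
# every image gets the same vector added at the same token position
y = x + pos_embed

print(tuple(pos_embed.shape))   # (1, 197, 768)
print(tuple(y.shape))           # (2, 197, 768)
print(pos_embed.requires_grad)  # True
print(torch.equal(y, x))        # True: all-zero embedding adds nothing yet
```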

Any feedback is appreciated.


Solution

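  • The positional embedding is an nn.Parameter, so it is registered with the model, enters the computational graph through x = x + self.pos_embed, and is updated by the optimizer during training like any other weight. The zeros are only a starting point: the gradient reaching each row of pos_embed is the gradient of the loss with respect to the token at that position (summed over the batch), and because tokens at different positions carry different content, each row generally receives a different update. The rows therefore quickly become distinct, position-specific vectors. So it doesn't matter that they are initialized with zeros; the embeddings are learned during training.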
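A minimal sketch of this, assuming a toy setup rather than timm's actual model: a single torch.nn.TransformerEncoderLayer stands in for the ViT blocks, a hypothetical linear head classifies from token 0 (the class-token slot), and random tensors stand in for the class token and patch embeddings. It checks that pos_embed starts at zero and that one SGD step gives it nonzero, position-dependent values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_patches, embed_dim, batch = 4, 8, 2  # tiny sizes, chosen only for illustration

# Zero-initialized, as in the question; nn.Parameter makes it trainable
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

# Hypothetical stand-in for the ViT blocks: one self-attention encoder layer
block = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True
)
head = nn.Linear(embed_dim, 3)  # hypothetical 3-class classifier on token 0

# Random stand-in for [class token, patch embeddings]: (batch, num_patches + 1, embed_dim)
x = torch.randn(batch, num_patches + 1, embed_dim)
target = torch.tensor([0, 2])

params = [pos_embed, *block.parameters(), *head.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1)

out = block(x + pos_embed)  # pos_embed broadcasts over the batch dimension
loss = nn.functional.cross_entropy(head(out[:, 0]), target)
loss.backward()   # gradients flow into pos_embed like any other weight
optimizer.step()  # so the zeros get updated

# Nonzero after one step, and different positions received different updates
print(pos_embed.abs().sum().item() > 0)
print(torch.allclose(pos_embed[0, 1], pos_embed[0, 2]))
```

The key point is that symmetry is broken by the data, not by the initialization: even though every row starts identical, each position sees different token content and attention weights, so each row's gradient differs.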