pytorch transformer-model vision-transformer

Positional encoding for Vision Transformer


Why is the positional encoding of size (1, patch, emb)? I would expect it to be (batch_size, patch, emb) in general. Yet even in the PyTorch GitHub code https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py they define:
self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT

Can anyone tell me what I should use as the positional encoding in my code? I tried:

self.pos_embedding = nn.Parameter(torch.empty(batch_size, seq_length, hidden_dim).normal_(std=0.02))

Is this correct?
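To make the difference between the two shapes concrete, here is a minimal sketch; the sizes (`seq_length=5`, `hidden_dim=8`, a training batch size of 4) and the helper `try_add` are hypothetical and chosen only for illustration. It adds a batch of tokens to a batch-independent embedding of shape (1, seq_length, hidden_dim) and to a batch-dependent one of shape (batch_size, seq_length, hidden_dim), for a full batch, a smaller last batch, and single-image inference.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
seq_length, hidden_dim = 5, 8
train_batch_size = 4

# Batch-independent embedding (the shape used in torchvision).
pos_shared = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))

# Batch-dependent embedding (the shape proposed in the question).
pos_per_batch = nn.Parameter(
    torch.empty(train_batch_size, seq_length, hidden_dim).normal_(std=0.02)
)


def try_add(x: torch.Tensor, pos: torch.Tensor):
    """Return the shape of x + pos, or None if the shapes cannot broadcast."""
    try:
        return tuple((x + pos).shape)
    except RuntimeError:  # raised by PyTorch on incompatible shapes
        return None


# Full batch, smaller last batch, single-image inference.
for batch_size in (4, 3, 1):
    x = torch.randn(batch_size, seq_length, hidden_dim)
    print(batch_size, try_add(x, pos_shared), try_add(x, pos_per_batch))
```

The shared embedding works for every batch size. The batch-dependent one raises for a batch of 3, and for a batch of 1 it silently broadcasts the single image to shape (4, 5, 8), which is wrong. It would also give each slot in the batch its own learned embedding, although the position of a sample inside a batch carries no meaning.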


Solution

  • You don't know the batch_size when initializing self.pos_embedding (and it can change from one batch to the next), so you should initialize this tensor as:

    self.pos_embedding = nn.Parameter(
        torch.empty(1, num_patches + 1, hidden_dim).normal_(std=0.02)
    ) 
    # (+ 1 because of the cls token, don't forget it)
    

    PyTorch takes care of broadcasting over the batch dimension in the forward pass:

    x = x + self.pos_embedding
    # (batch_size, num_patches + 1, hidden_dim) + (1, num_patches + 1, hidden_dim) broadcasts fine
    

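    This does not work for the cls token, though: the cls token is concatenated to the patch embeddings with torch.cat rather than added, and torch.cat does not broadcast. You should expand this tensor to the batch size in forward: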

    # (1, 1, hidden_dim) -> (batch_size, 1, hidden_dim); expand returns a view, no copy
    cls_token = self.cls_token.expand(
        batch_size, -1, -1
    )
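Putting the pieces of the answer together, a minimal module might look like the sketch below. The class name `PatchTokens`, the zero initialization of the cls token, and the sizes in the usage lines (196 patches of a 224x224 image with 16x16 patches, hidden size 768) are assumptions for illustration; the patch-embedding step that produces `x` is assumed to happen elsewhere.

```python
import torch
import torch.nn as nn


class PatchTokens(nn.Module):
    """Hypothetical sketch: prepend a cls token, then add positional embeddings."""

    def __init__(self, num_patches: int, hidden_dim: int):
        super().__init__()
        # One cls token shared by all images: (1, 1, hidden_dim).
        # Zero init is an assumption made for this sketch.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        # One embedding per token (+ 1 for the cls token), shared across the
        # batch: (1, num_patches + 1, hidden_dim).
        self.pos_embedding = nn.Parameter(
            torch.empty(1, num_patches + 1, hidden_dim).normal_(std=0.02)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, num_patches, hidden_dim)
        batch_size = x.shape[0]
        # torch.cat does not broadcast, so give the cls token a batch dimension.
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)  # (batch_size, num_patches + 1, hidden_dim)
        # Addition broadcasts: (B, N + 1, D) + (1, N + 1, D) -> (B, N + 1, D).
        return x + self.pos_embedding


tokens = PatchTokens(num_patches=196, hidden_dim=768)
for batch_size in (2, 5):
    out = tokens(torch.randn(batch_size, 196, 768))
    print(out.shape)
```

The same module handles batches of 2 and 5 without any change, and its parameter count does not depend on the batch size, which is the point of the leading dimension of 1.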