
Equivalent of torch EmbeddingBag


The PyTorch docs claim that EmbeddingBag with mode="sum" is equivalent to Embedding followed by torch.sum(dim=1), but how can I implement this in detail? Say we have "EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)"; how can we replace "nn.EmbeddingBag" with "nn.Embedding" and "torch.sum" equivalently? Many thanks


Solution

  • Consider the following example, in which all four implementations yield the same result. Here input is a 1D tensor of indices, and the second argument, the offsets tensor torch.zeros(1).long(), treats all of input as a single bag:

    • nn.EmbeddingBag:

      >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
      >>> embedding_sum(input, torch.zeros(1).long())
      
    • nn.functional.embedding_bag:

      >>> F.embedding_bag(input, embedding_sum.weight, torch.zeros(1).long(), mode='sum')
      
    • nn.Embedding:

      >>> embedding = nn.Embedding(10, 3)
      >>> embedding.weight = embedding_sum.weight
      >>> embedding(input).sum(0)
      
    • nn.functional.embedding:

      >>> F.embedding(input, embedding_sum.weight).sum(0)
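The four variants above can be checked end to end in a single script. This is a sketch with an arbitrary sample input (any 1D tensor of indices below 10 works); `keepdim=True` is used in the Embedding variants so the summed result also has the `(1, 3)` shape that EmbeddingBag returns for a single bag:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Sample input (arbitrary): a single bag of 8 indices into a 10-row table.
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.zeros(1, dtype=torch.long)  # one bag starting at position 0

# 1) nn.EmbeddingBag with mode='sum'
embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
out_bag = embedding_sum(input, offsets)

# 2) functional embedding_bag, reusing the same weight matrix
out_fbag = F.embedding_bag(input, embedding_sum.weight, offsets, mode='sum')

# 3) nn.Embedding sharing the weight, then summing over the bag dimension
embedding = nn.Embedding(10, 3)
embedding.weight = embedding_sum.weight
out_emb = embedding(input).sum(0, keepdim=True)

# 4) functional embedding, then sum
out_femb = F.embedding(input, embedding_sum.weight).sum(0, keepdim=True)

# All four produce the same (1, 3) tensor.
assert torch.allclose(out_bag, out_fbag)
assert torch.allclose(out_bag, out_emb)
assert torch.allclose(out_bag, out_femb)
```

Note that sharing the weight (rather than copying it) is what makes the comparison exact: all four calls read from the same randomly initialized embedding table.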