PyTorch claims that EmbeddingBag with mode="sum" is equivalent to Embedding followed by torch.sum(dim=1), but how can I implement this in detail? Say we have "EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)": how can we replace "nn.EmbeddingBag" with "nn.Embedding" and "torch.sum" equivalently? Many thanks

  • Does this answer your question? [How embedding\_bag exactly works in PyTorch](https://stackoverflow.com/questions/62052734/how-embedding-bag-exactly-works-in-pytorch) – Ivan Sep 30 '21 at 17:40
  • Hi Ivan, thank you for the reply. Actually, I had checked that answer before asking this question. I don't know how to implement it in a similar way for nn.EmbeddingBag instead of nn.functional.embedding_bag. It seems that there is no member function "sum" for nn.Embedding. – user2584234 Oct 01 '21 at 13:09

1 Answer

Consider the following example, where all four implementations yield the same summed values. Here `input` is a 1D tensor of indices treated as a single bag, e.g. `input = torch.tensor([1, 2, 4, 5])`; the two bag variants return shape `(1, 3)`, the two plain-embedding variants return shape `(3,)`:

  • nn.EmbeddingBag:

    >>> import torch
    >>> import torch.nn as nn
    >>> import torch.nn.functional as F
    >>> input = torch.tensor([1, 2, 4, 5])  # example indices, one bag
    >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
    >>> embedding_sum(input, torch.zeros(1).long())
    
  • nn.functional.embedding_bag:

    >>> F.embedding_bag(input, embedding_sum.weight, torch.zeros(1).long(), mode='sum')
    
  • nn.Embedding:

    >>> embedding = nn.Embedding(10, 3)
    >>> embedding.weight = embedding_sum.weight
    >>> embedding(input).sum(0)
    
  • nn.functional.embedding:

    >>> F.embedding(input, embedding_sum.weight).sum(0)
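Putting the four variants together, here is a minimal self-contained sketch that checks they agree numerically. The index tensor `input` and the sizes `(10, 3)` are illustrative choices, not mandated by the question; the key step is sharing `embedding_sum.weight` across all four calls.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical example: four token ids forming a single bag.
input = torch.tensor([1, 2, 4, 5])
offsets = torch.zeros(1).long()  # one bag starting at position 0

# 1) nn.EmbeddingBag -> shape (1, 3)
embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
out_bag = embedding_sum(input, offsets)

# 2) functional embedding_bag, reusing the same weight -> shape (1, 3)
out_fbag = F.embedding_bag(input, embedding_sum.weight, offsets, mode='sum')

# 3) nn.Embedding sharing the weight, then summing over the bag -> shape (3,)
embedding = nn.Embedding(10, 3)
embedding.weight = embedding_sum.weight
out_emb = embedding(input).sum(0)

# 4) functional embedding, then sum -> shape (3,)
out_femb = F.embedding(input, embedding_sum.weight).sum(0)

# All four produce the same summed embedding values.
assert torch.allclose(out_bag.squeeze(0), out_fbag.squeeze(0))
assert torch.allclose(out_bag.squeeze(0), out_emb)
assert torch.allclose(out_bag.squeeze(0), out_femb)
```

Note the only difference is shape: `nn.EmbeddingBag` returns one row per bag (`(1, 3)` here), while the `nn.Embedding` + `sum` route returns the bare `(3,)` vector, so squeeze or unsqueeze accordingly when swapping one for the other.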
    
Ivan