PyTorch claims that `EmbeddingBag` with `mode="sum"` is equivalent to `Embedding` followed by `torch.sum(dim=1)`, but how can I implement that in detail? Say we have `EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)`. How can we replace `nn.EmbeddingBag` with `nn.Embedding` and `torch.sum` equivalently? Many thanks.
- Does this answer your question? [How embedding\_bag exactly works in PyTorch](https://stackoverflow.com/questions/62052734/how-embedding-bag-exactly-works-in-pytorch) – Ivan Sep 30 '21 at 17:40
- Hi Ivan, thank you for the reply. Actually, I had checked that answer before asking this question. I don't know how to implement it in a similar way for `nn.EmbeddingBag` instead of `nn.functional.embedding_bag`. It seems that there is no member function `sum` for `nn.Embedding`. – user2584234 Oct 01 '21 at 13:09
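The `sum` is not a member of the `nn.Embedding` module; it is applied to the tensor the module returns. Below is a minimal sketch of the replacement for the question's `EE`. It assumes a 2-D input in which every row is one bag of equal length, so `EmbeddingBag` needs no offsets. The sizes `n`, `m` and the `indices` tensor are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n, m = 10, 3  # assumed vocabulary size and embedding dimension

# The module from the question.
EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)

# Replacement: a plain lookup table holding the same weights.
emb = nn.Embedding(n, m, sparse=True)
with torch.no_grad():
    emb.weight.copy_(EE.weight)

# 2-D input: each row is one bag of fixed length 3 (illustrative indices).
indices = torch.tensor([[1, 2, 4], [4, 3, 9]])

out_bag = EE(indices)              # shape (2, 3): one summed vector per bag
out_sum = emb(indices).sum(dim=1)  # (2, 3, 3) summed over the bag dimension -> (2, 3)

print(torch.allclose(out_bag, out_sum))  # True

# sparse=True on nn.Embedding also yields a sparse gradient, as EmbeddingBag would.
out_sum.sum().backward()
print(emb.weight.grad.is_sparse)  # True
```

Copying the weights keeps two independent tables. Assigning the parameter instead (`emb.weight = EE.weight`) makes both modules share and train one table.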
1 Answer
Consider the following example where all four implementations yield the same result:
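Setup (imports, plus a hypothetical 1-D `input` of indices forming a single bag):

```python
>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> input = torch.tensor([1, 2, 4, 5])  # hypothetical indices, one bag
```

`nn.EmbeddingBag` with a single offset (one bag):

```python
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
>>> embedding_sum(input, torch.zeros(1).long())
```

Functional `F.embedding_bag` with the same weights:

```python
>>> F.embedding_bag(input, embedding_sum.weight, torch.zeros(1).long(), mode='sum')
```

`nn.Embedding` sharing the weight, followed by a sum over the bag dimension:

```python
>>> embedding = nn.Embedding(10, 3)
>>> embedding.weight = embedding_sum.weight
>>> embedding(input).sum(0)
```

Functional `F.embedding` followed by the sum:

```python
>>> F.embedding(input, embedding_sum.weight).sum(0)
```

The values match. The two `EmbeddingBag` variants return shape `(1, 3)` (one row per bag), while the two `Embedding` variants return shape `(3,)`.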
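The example above uses a single bag (offsets `[0]`). With several bags of different lengths, a single `.sum(0)` no longer works, and the `Embedding` route needs one sum per bag. The sketch below assumes illustrative indices and offsets. It uses `index_add_` to add up the rows that share a bag id; that is one possible reduction, not necessarily what `EmbeddingBag` does internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

bag = nn.EmbeddingBag(10, 3, mode="sum")

# 1-D input holding three bags of lengths 2, 3 and 1, delimited by offsets.
indices = torch.tensor([1, 2, 4, 5, 4, 9])
offsets = torch.tensor([0, 2, 5])

expected = bag(indices, offsets)  # shape (3, 3): one row per bag

# Look up every index individually with a plain embedding.
vectors = F.embedding(indices, bag.weight)  # shape (6, 3)

# Bag length = difference between consecutive offsets (last bag runs to the end).
lengths = torch.diff(offsets, append=torch.tensor([indices.numel()]))
# Bag id of every position, e.g. [0, 0, 1, 1, 1, 2].
bag_ids = torch.repeat_interleave(torch.arange(offsets.numel()), lengths)

# Sum the vectors that share a bag id.
result = torch.zeros(offsets.numel(), bag.embedding_dim).index_add_(0, bag_ids, vectors)

print(torch.allclose(expected, result))  # True
```

When all bags have the same length, this reduces to reshaping the lookups into a 2-D batch and calling `.sum(dim=1)`.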

Ivan