25

From multiple searches and the PyTorch documentation itself, I could figure out that inside an embedding layer there is a lookup table where the embedding vectors are stored. What I am not able to understand is:

  1. What exactly happens during training in this layer?
  2. What are the weights, and how are the gradients of those weights computed?
  3. My intuition is that at least there should be a function with some parameters that produces the keys for the lookup table. If so, then what is that function?

Any help with this will be appreciated. Thanks.

Rituraj Kaushik

2 Answers

33

That is a really good question! The embedding layer of PyTorch (the same goes for TensorFlow) serves as a lookup table just to retrieve the embeddings for each of the inputs, which are indices. Consider the following case: you have a sentence where each word is tokenized, so each word in your sentence is represented by a unique integer (index). If the list of indices (words) is [1, 5, 9], and you want to encode each of the words with a 50-dimensional vector (embedding), you can do the following:

import torch

# The list of tokens
tokens = torch.tensor([1, 5, 9], dtype=torch.long)
# Define an embedding layer, where you know upfront that in total you
# have 10 distinct words, and you want each word to be encoded with
# a 50 dimensional vector
embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=50)
# Obtain the embeddings for each of the words in the sentence
embedded_words = embedding(tokens)
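
For reference, the result is a tensor with one 50-dimensional vector per token:

print(embedded_words.shape)  # torch.Size([3, 50]): 3 tokens, 50 dimensions each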

Now, to answer your questions:

  1. During the forward pass, the values for each of the tokens in your sentence are obtained in a similar way to how NumPy indexing works. Because this is a differentiable operation in the backend, during the backward pass (training), PyTorch computes the gradients for each of the embeddings and adjusts them accordingly.

  2. The weights are the embeddings themselves: the word embedding matrix is actually a weight matrix that is learned during training (see the sketch right after this list).

  3. There is no actual function per se. As we defined above, the sentence is already tokenized (each word is represented with a unique integer), and we can just obtain the embeddings for each of the tokens in the sentence.
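
To make points 1 and 2 concrete, here is a minimal sketch (using the same 10-word / 50-dimensional setup as above, with a dummy loss just for illustration) showing that the layer's only parameter is the weight matrix, and that after the backward pass only the rows of the looked-up indices receive non-zero gradients:

import torch

embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=50)
print(embedding.weight.shape)  # torch.Size([10, 50]): one row per word

tokens = torch.tensor([1, 5, 9], dtype=torch.long)
loss = embedding(tokens).sum()  # a dummy scalar "loss", just for illustration
loss.backward()

# Only rows 1, 5 and 9 of the weight matrix get non-zero gradients;
# the optimizer therefore only moves the embeddings that were actually used.
print(embedding.weight.grad.abs().sum(dim=1))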

Finally, since I have mentioned indexing several times, let us try it out.

import numpy as np
import torch

# Let us assume that we have a pre-trained embedding matrix
pretrained_embeddings = torch.rand(10, 50)
# We can initialize our embedding module from the embedding matrix
embedding = torch.nn.Embedding.from_pretrained(pretrained_embeddings)
# Some tokens
tokens = torch.tensor([1, 5, 9], dtype=torch.long)

# Token embeddings from the lookup table
lookup_embeddings = embedding(tokens)
# Token embeddings obtained with indexing
indexing_embeddings = pretrained_embeddings[tokens]
# Voila! They are the same
np.testing.assert_array_equal(lookup_embeddings.numpy(), indexing_embeddings.numpy())
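
Relatedly (this also comes up in the comments below), the same lookup can be viewed as a one-hot encoding followed by a matrix multiplication with the embedding matrix, i.e. a linear layer without a bias. A minimal sketch, reusing the matrix from above:

import numpy as np
import torch
import torch.nn.functional as F

pretrained_embeddings = torch.rand(10, 50)
tokens = torch.tensor([1, 5, 9], dtype=torch.long)

# One-hot encode the tokens: shape (3, 10)
one_hot = F.one_hot(tokens, num_classes=10).float()
# Multiplying the one-hot matrix by the embedding matrix selects the same rows
matmul_embeddings = one_hot @ pretrained_embeddings

np.testing.assert_array_equal(matmul_embeddings.numpy(),
                              pretrained_embeddings[tokens].numpy())
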
gorjan
  • So this is exactly the same as a one-hot encoding followed by a linear layer? – user2653663 Dec 14 '20 at 13:59
  • Exactly. I plan on writing a blog post when I have the time and I will update the answer with the link. – gorjan Dec 14 '20 at 14:10
  • In your description, you said `In case the list of indices (words) is [1, 5, 9]`, but your code says `tokens = torch.tensor([0,5,9],`. Why the change from `[1,5,9]` to `[0,5,9]`? – Richie Thomas Jun 15 '21 at 15:15
  • Because when you don't double check what you write, you make typos :) Changed now :) – gorjan Jun 15 '21 at 15:59
  • This is a really good explanation. How does backpropagation update the embedding matrix, since different indices are picked up for different inputs? – Sandeep Pandey Aug 22 '23 at 09:28
3

The nn.Embedding layer can serve as a lookup table. This means that if you have a dictionary of n elements, you can retrieve each element by its index once you create the embedding.

In this case the size of the dictionary would be num_embeddings and the embedding_dim would be 1.

You don't have anything to learn in this scenario; you are simply indexing elements of a dictionary (encoding them, you might say), so no analysis of the forward pass is needed in this case.

You may have used this setup if you have loaded pre-trained word embeddings such as Word2vec.
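
A minimal sketch of that frozen-lookup case (the random matrix below merely stands in for real pre-trained vectors such as Word2vec, and the sizes are made up for illustration):

import torch

# Placeholder for a real pre-trained matrix: 1000 words, 300 dimensions each
word2vec_like = torch.rand(1000, 300)
embedding = torch.nn.Embedding.from_pretrained(word2vec_like, freeze=True)

print(embedding.weight.requires_grad)  # False: nothing is learned, it is a pure lookup
print(embedding(torch.tensor([42], dtype=torch.long)).shape)  # torch.Size([1, 300])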

On the other hand, you may use embedding layers for categorical variables (features, in the general case). In that case you set num_embeddings to the number of categories you have, and embedding_dim to the size of the vector you want to learn for each category.

Here you start with a randomly initialized embedding layer, and the representation of each category (feature) is learned during training.
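
A minimal sketch of that second use case; the feature name and sizes below are made up purely for illustration:

import torch

# Hypothetical categorical feature "colour" with 4 categories, each to be
# learned as an 8-dimensional vector (both sizes chosen arbitrarily here).
colour_embedding = torch.nn.Embedding(num_embeddings=4, embedding_dim=8)

# A batch of category ids for that feature
colour_ids = torch.tensor([0, 2, 3, 1], dtype=torch.long)
colour_vectors = colour_embedding(colour_ids)  # shape (4, 8)

# colour_vectors starts out random; during training an optimizer would update
# colour_embedding.weight together with the rest of the model.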

prosti