
I have seen two different implementations of Multi-Head Attention.

  1. In one approach, the queries, keys, and values are split into heads before being passed through the linear layers, as shown below:

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, seq_len, heads, projection_dim)
        return x.reshape(batch_size, -1, self.heads, self.projection_dim)

    def forward(self, queries, keys, values, mask):
        batch_size = queries.size(0)

        # split queries, keys and values into heads first
        queries = self.split_heads(queries, batch_size)
        keys = self.split_heads(keys, batch_size)
        values = self.split_heads(values, batch_size)

        # then project; for the shapes to work out, these linear layers must map
        # projection_dim -> projection_dim, with the same weights shared across heads
        queries = self.queries_linear(queries)
        keys = self.keys_linear(keys)
        values = self.values_linear(values)
        # ...more code

  2. The second approach is to split the queries, keys, and values into heads after passing them through the linear layers (see the sketch after this list):
    def forward(self, queries, keys, values, mask=None):
        batch_size = queries.size(0)

        # perform the full-width linear projections first, then split into h heads
        q = self.queries_linear(queries).view(batch_size, -1, self.heads, self.projection_dim)
        k = self.keys_linear(keys).view(batch_size, -1, self.heads, self.projection_dim)
        v = self.values_linear(values).view(batch_size, -1, self.heads, self.projection_dim)
        # ...more code
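
Here is a minimal, self-contained sketch contrasting the two approaches (the dimensions `d_model = 512` and `heads = 8` are illustrative assumptions, not taken from either snippet):

    import torch
    import torch.nn as nn

    # illustrative dimensions (assumptions, not from the snippets above)
    d_model, heads = 512, 8
    projection_dim = d_model // heads  # 64

    x = torch.randn(2, 10, d_model)  # (batch, seq_len, d_model)
    batch_size = x.size(0)

    # Approach 1: split first, then a small per-head linear shared across heads.
    # Each head's output depends only on its own 64-dim slice of the input.
    linear_per_head = nn.Linear(projection_dim, projection_dim)
    split_first = linear_per_head(x.reshape(batch_size, -1, heads, projection_dim))

    # Approach 2: one full-width linear, then split.
    # Each head's output can depend on all 512 input dimensions.
    linear_full = nn.Linear(d_model, d_model)
    linear_first = linear_full(x).view(batch_size, -1, heads, projection_dim)

    print(split_first.shape)   # torch.Size([2, 10, 8, 64])
    print(linear_first.shape)  # torch.Size([2, 10, 8, 64])

Both produce the same output shape, but they are different parameterizations: the per-head linear in approach 1 has far fewer parameters (64×64, shared across heads) than the full projection in approach 2 (512×512), and it prevents a head from mixing information across slices of the embedding.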
        

From what I can deduce from the Multi-Head Attention figure in the paper Attention Is All You Need, the queries and keys should be split before being passed through the linear layers, but in most implementations online they are split after.

Are the two approaches similar, or is one better than the other?

Kinyugo
  • Well what domains were the two different implementations? NLP, image, video, ...? Please give citations and more detail. – smci Jul 30 '20 at 17:50
  • Both implementations are in the field of NLP – Kinyugo Jul 30 '20 at 18:34
  • **Please give citations and more detail.** Post the actual links. What type of NLP? Question-Answering? Topic Classification? Summarization of biomedical journal articles? Chatbot? (closed-domain/open-domain?) something else? character level? word-level? sentence-piece level? etc etc. You really need to give much more detail. – smci Jul 30 '20 at 18:41
  • If I were to follow any implementation, it would be the [official](https://github.com/tensorflow/tensor2tensor) one! That being said, either method may be used and the results compared for some same task. – anurag Jan 01 '21 at 10:13

1 Answer


Yes, I think the figure in the paper is quite confusing.
But according to the UvA Deep Learning tutorial hosted in the PyTorch Lightning docs, the linear projection comes first and the split into heads comes after. I think this way the linear layer gets to decide how to distribute the values across the heads:
https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/05-transformers-and-MH-attention.html
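
A minimal sketch of that linear-then-split pattern (my own illustrative code, not the tutorial's; the class name and dimensions are assumptions):

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MultiHeadAttention(nn.Module):
        def __init__(self, d_model, heads):
            super().__init__()
            assert d_model % heads == 0
            self.heads = heads
            self.projection_dim = d_model // heads
            # one full-width projection each; heads are split off afterwards
            self.queries_linear = nn.Linear(d_model, d_model)
            self.keys_linear = nn.Linear(d_model, d_model)
            self.values_linear = nn.Linear(d_model, d_model)
            self.out_linear = nn.Linear(d_model, d_model)

        def forward(self, queries, keys, values, mask=None):
            batch_size = queries.size(0)
            # linear first, then split:
            # (batch, seq, d_model) -> (batch, heads, seq, projection_dim)
            q = self.queries_linear(queries).view(batch_size, -1, self.heads, self.projection_dim).transpose(1, 2)
            k = self.keys_linear(keys).view(batch_size, -1, self.heads, self.projection_dim).transpose(1, 2)
            v = self.values_linear(values).view(batch_size, -1, self.heads, self.projection_dim).transpose(1, 2)

            # scaled dot-product attention, computed per head
            scores = q @ k.transpose(-2, -1) / math.sqrt(self.projection_dim)
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float("-inf"))
            attn = F.softmax(scores, dim=-1)

            # merge heads back: (batch, heads, seq, proj) -> (batch, seq, d_model)
            out = (attn @ v).transpose(1, 2).reshape(batch_size, -1, self.heads * self.projection_dim)
            return self.out_linear(out)

    # usage: self-attention over a toy batch
    mha = MultiHeadAttention(512, 8)
    x = torch.randn(2, 10, 512)
    print(mha(x, x, x).shape)  # torch.Size([2, 10, 512])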

Edit:
I found more evidence that the linear layers come before the split; just look at the GPT architecture here: https://upload.wikimedia.org/wikipedia/commons/9/91/Full_GPT_architecture.png You can clearly see that the projection happens before the split into heads.

TheEngineerProgrammer