
I have n vectors, each of dimensionality d, which need to attend to each other and produce n output vectors of the same dimensionality d. I believe this is what torch.nn.MultiheadAttention does, but its forward function expects query, key, and value as inputs. According to this blog, I need to initialize a random weight matrix of shape (d x d) for each of q, k, and v, multiply each of my vectors by these weight matrices, and get three (n x d) matrices. Are the q, k, and v expected by torch.nn.MultiheadAttention just these three matrices, or do I have it mistaken?
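For what it's worth, here is a minimal sketch of the manual computation that blog seems to describe, assuming a single head; the names W_q, W_k, and W_v (and the sizes) are placeholders I made up:

import torch

n, d = 5, 16                          # n input vectors of dimensionality d
x = torch.randn(n, d)

# one random (d x d) projection matrix per role, as the blog describes
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
q, k, v = x @ W_q, x @ W_k, x @ W_v   # three (n x d) matrices

# scaled dot-product attention over the projected vectors
scores = torch.softmax(q @ k.T / d ** 0.5, dim=-1)
out = scores @ v                      # (n x d), same shape as the input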

angryweasel

1 Answer


When you want to use self-attention, just pass your input tensor as the query, key, and value to torch.nn.MultiheadAttention. You do not need to project it yourself: the module creates and applies the q, k, and v weight matrices internally.


attention = torch.nn.MultiheadAttention(<input-size>, <num-heads>)

# x: (seq_len, batch, embed_dim) by default; returns output and attention weights
x, _ = attention(x, x, x)

The PyTorch module returns the output states (same shape as the input) and the attention weights, averaged across heads by default.
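For example, a minimal self-contained run (the sizes here are arbitrary choices for illustration):

import torch

embed_dim, num_heads, seq_len, batch = 16, 4, 5, 2
attention = torch.nn.MultiheadAttention(embed_dim, num_heads)

# the default input layout is (seq_len, batch, embed_dim)
x = torch.randn(seq_len, batch, embed_dim)
out, weights = attention(x, x, x)

print(out.shape)      # torch.Size([5, 2, 16]) -- same shape as the input
print(weights.shape)  # torch.Size([2, 5, 5]) -- weights averaged over heads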

Theodor Peifer