I have n vectors which need to be influenced by each other, and I want to output n vectors with the same dimensionality d. I believe this is what torch.nn.MultiheadAttention does. But the forward function expects query, key and value as inputs. According to this blog, I need to initialize a random weight matrix of shape (d x d) for each of q, k and v, multiply each of my vectors by these weight matrices, and get three (n x d) matrices. Are the q, k and v expected by torch.nn.MultiheadAttention just these three matrices, or do I have it mistaken?
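In concrete terms, the procedure the blog describes would look something like this (the sizes and weight names below are my own illustration, not from the blog):

import torch

n, d = 10, 64          # illustrative sizes
x = torch.randn(n, d)  # n input vectors of dimensionality d
W_q = torch.randn(d, d)  # one random (d x d) weight matrix per role
W_k = torch.randn(d, d)
W_v = torch.randn(d, d)
q, k, v = x @ W_q, x @ W_k, x @ W_v  # three (n x d) matrices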

angryweasel
1 Answer
When you want self-attention, just pass your input tensor to torch.nn.MultiheadAttention as the query, key and value:
attention = torch.nn.MultiheadAttention(<input-size>, <num-heads>)
x, _ = attention(x, x, x)
The PyTorch class returns the output states (same shape as the input) and the weights used in the attention process (averaged over the heads by default).
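A fuller sketch with concrete (arbitrary) sizes — note that without batch_first=True the module expects inputs shaped (seq_len, batch, embed_dim):

import torch

seq_len, batch, embed_dim, num_heads = 10, 1, 64, 8  # illustrative sizes
attention = torch.nn.MultiheadAttention(embed_dim, num_heads)

x = torch.randn(seq_len, batch, embed_dim)  # default layout is (seq_len, batch, embed_dim)
out, weights = attention(x, x, x)           # self-attention: same tensor for query, key and value
print(out.shape)      # torch.Size([10, 1, 64]) -- same shape as the input
print(weights.shape)  # torch.Size([1, 10, 10]) -- attention weights, averaged over heads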

Theodor Peifer
- Are the inputs `x` supposed to be sequences of token ids or embeddings? – Evan Zamir May 07 '21 at 17:14
- `x` is the embeddings – Mehrdad Dec 14 '21 at 20:37
- This also assumes the k, q and v dimensions are the same – Kaare May 13 '22 at 10:38
- Slightly cleaner: `x, _ = attention(x, x, x, need_weights=False)` – Rocco Fortuna Feb 08 '23 at 16:58
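On Kaare's point above: when the key/value inputs do have a different dimensionality than the query, torch.nn.MultiheadAttention accepts kdim and vdim arguments. A minimal sketch (the sizes here are illustrative):

import torch

# Queries of dim 64, keys/values of dim 32.
attn = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8, kdim=32, vdim=32)
q = torch.randn(10, 1, 64)   # (seq_len, batch, embed_dim)
kv = torch.randn(10, 1, 32)  # (seq_len, batch, kdim) / (seq_len, batch, vdim)
out, _ = attn(q, kv, kv)     # output comes back in the query's dimension: (10, 1, 64)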