
The AttentionQKV layer implemented in Trax is as follows:

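# From trax/layers/attention.py -- here `cb` is trax.layers.combinators,
# `core` is trax.layers.core, and `PureAttention` is defined in the same module.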
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).
  See `Attention` above for further context/details.
  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )

In particular, what is the purpose of the three parallel dense layers? The input to this layer is (q, k, v, mask). Why are q, k, and v put through a dense layer?

Charles Ju

1 Answer


This code snippet is an implementation of the multi-head attention equation at the top of page 5 of the "Attention Is All You Need" paper that introduced the Transformer model in 2017. The computation is illustrated in Figure 2 of the paper:

[Figure 2 of the paper: (left) Scaled Dot-Product Attention; (right) Multi-Head Attention, several attention layers running in parallel.]
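For reference, the equation from the top of page 5 is

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\ K W_i^K,\ V W_i^V)$$

with scaled dot-product attention

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V.$$

The three parallel Dense(d_feature) layers in the Trax snippet play the role of the learned projections W^Q, W^K, W^V (for all heads at once), and the final Dense(d_feature) is W^O.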

The hidden states get projected into h attention heads, which perform scaled dot-product attention in parallel. The projections can be interpreted as extracting the information that is relevant for each head. Each head then does a probabilistic retrieval based on different (learned) criteria.
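Here is a rough illustration of that projection-plus-head-split in plain NumPy (a minimal sketch under my own assumptions about shapes and weight names, not the actual Trax implementation, and without masking or dropout):

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

d_feature, n_heads = 8, 2
d_head = d_feature // n_heads
rng = np.random.default_rng(0)

# The three parallel Dense layers correspond to learned projections W_q, W_k, W_v;
# the final Dense corresponds to W_o.
W_q, W_k, W_v, W_o = (rng.normal(size=(d_feature, d_feature)) for _ in range(4))

def attention_qkv(q, k, v):
    # Project the inputs -- this is what cb.Parallel(Dense, Dense, Dense) does.
    q, k, v = q @ W_q, k @ W_k, v @ W_v

    # Split each projected vector into n_heads pieces of size d_head.
    def split(x):  # (seq, d_feature) -> (n_heads, seq, d_head)
        return x.reshape(x.shape[0], n_heads, d_head).transpose(1, 0, 2)
    qh, kh, vh = split(q), split(k), split(v)

    # Scaled dot-product attention, computed independently per head.
    scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(d_head)
    out = softmax(scores) @ vh          # (n_heads, seq, d_head)

    # Merge the heads back together and apply the final Dense(d_feature).
    out = out.transpose(1, 0, 2).reshape(-1, d_feature)
    return out @ W_o

x = rng.normal(size=(4, d_feature))     # self-attention: q = k = v = x
print(attention_qkv(x, x, x).shape)     # (4, 8)

Without the projections, every head would see the same unweighted copy of the hidden states; with them, each head can learn to attend over a different subspace of the representation.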

Jindřich
  • So the Q, K, V in the picture above actually mean the hidden states? Somehow I thought Q, K, V are the result of multiplying the hidden states by the Query, Key and Value matrices. – Charles Ju Oct 01 '20 at 15:16
  • Yes, in self-attention all three are the same. In the encoder-decoder attention, Q comes from the decoder, while K and V are the encoder states. – Jindřich Oct 01 '20 at 15:37