I'm learning multi-head attention from this article. According to the author, the structure of MHA (from the original paper) is as follows:
But the MultiHeadAttention layer in TensorFlow seems to be more flexible:
- It does not require key_dim * num_heads = embed_dim. For example:
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = tf.keras.Input(shape=[3, 5])  # embed_dim = 5, while key_dim * num_heads = 8
layer(x, x)
# no error
Is the depth of the weight matrices in the tf MHA layer set to key_dim * num_heads regardless of embed_dim, so that Q/K/V can still be properly split across num_heads heads? (See the shape check after this list.)
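To check my understanding, here is a minimal sketch of how I would inspect the projection weight shapes once the layer is built (assuming the per-head projection kernels are visible through layer.weights; the expected shapes in the comments are my own guess, not something stated in the article):

import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = tf.keras.Input(shape=[3, 5])  # embed_dim = 5
layer(x, x)  # builds the layer's weights

# Print every weight shape; I expect the Q/K/V projection kernels to be
# (embed_dim, num_heads, key_dim) = (5, 2, 4) and the output projection
# kernel to be (num_heads, key_dim, embed_dim) = (2, 4, 5).
for w in layer.weights:
    print(w.name, w.shape)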
- However, the output depth of the tf MHA layer is (by default) guaranteed to be embed_dim. So is there a final dense layer with embed_dim units to restore the dimension? (Sketch below.)
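This is the small check I used for the output dimension; the output_shape argument on the second layer is only there to show that the default can be overridden (I'm assuming the final dense projection is what produces these shapes):

import tensorflow as tf

x = tf.keras.Input(shape=[3, 5])  # embed_dim = 5

# Default: the output's last dimension matches the query's last dimension
default_layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
print(default_layer(x, x).shape)  # (None, 3, 5)

# With output_shape set, the final projection maps to that size instead
custom_layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4, output_shape=7)
print(custom_layer(x, x).shape)  # (None, 3, 7)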