
I am trying to use the AdditiveAttention layer in Keras and comparing it against a manual implementation of the layer from the TensorFlow tutorial https://www.tensorflow.org/tutorials/text/nmt_with_attention

import tensorflow as tf 

class BahdanauAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values):
    query_with_time_axis = tf.expand_dims(query, 1)
    score = self.V(tf.nn.tanh(
        self.W1(query_with_time_axis) + self.W2(values)))
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = attention_weights * values
    context_vector = tf.reduce_sum(context_vector, axis=1)
    return context_vector, attention_weights

The shape of the context_vector here is (batch_size, hidden_size), i.e. a 2D tensor.
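For example (the variable names and shapes below are only illustrative, following the tutorial's setup of a 2D decoder state as query and a 3D encoder output as values):

attention_layer = BahdanauAttention(10)

sample_hidden = tf.random.uniform((64, 1024))      # query:  (batch_size, hidden_size)
sample_output = tf.random.uniform((64, 16, 1024))  # values: (batch_size, max_len, hidden_size)

context_vector, attention_weights = attention_layer(sample_hidden, sample_output)
context_vector.shape, attention_weights.shape
(TensorShape([64, 1024]), TensorShape([64, 16, 1]))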

Whereas when I use the built-in AdditiveAttention layer from Keras,

from tensorflow.keras.layers import AdditiveAttention

the shape of the context_vector is [batch_size, Tq, dim].

Any suggestions on what is causing this output shape difference would be appreciated.

data_person

1 Answer


Both implementations are quite similar, with a few differences. The BahdanauAttention implementation in that tutorial is a somewhat simplified and adapted version that applies extra linear transformations (the W1, W2 and V Dense layers). The context_vector return shape you're asking about is really just a matter of the input shapes. Here is a demonstration; let's start with the tutorial implementation:

class BahdanauAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V  = tf.keras.layers.Dense(1)

  def call(self, query, values):
    query_with_time_axis = tf.expand_dims(query, 1)
    score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
    attention_weights = tf.nn.softmax(score, axis=1)
    context_vector = attention_weights * values
    context_vector = tf.reduce_sum(context_vector, axis=1)
    return context_vector, attention_weights

Now, let's pass some inputs to it, first 3D and then 2D.

attention_layer = BahdanauAttention(10)

y = tf.random.uniform((2, 60, 512))  
out, attn = attention_layer(y, y)
out.shape , attn.shape
(TensorShape([2, 60, 512]), TensorShape([2, 2, 60, 1]))

y = tf.random.uniform((2, 512))  
out, attn = attention_layer(y, y)
out.shape , attn.shape
(TensorShape([2, 512]), TensorShape([2, 2, 1]))
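The doubled leading 2 in attn.shape for the 3D input comes from broadcasting: this class expects a 2D query, so tf.expand_dims(query, 1) on a 3D tensor adds an extra axis, and the addition inside the score then broadcasts the batch dimension against it. A quick shape check (the Dense layers below are fresh ones created only to trace shapes, they just stand in for self.W1 and self.W2):

w1 = tf.keras.layers.Dense(10)   # stands in for self.W1
w2 = tf.keras.layers.Dense(10)   # stands in for self.W2

y = tf.random.uniform((2, 60, 512))
q = tf.expand_dims(y, 1)         # (2, 1, 60, 512)
(w1(q) + w2(y)).shape            # (2, 1, 60, 10) + (2, 60, 10) broadcasts
TensorShape([2, 2, 60, 10])

That extra 2 is exactly what shows up in attention_weights as (2, 2, 60, 1).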

Now, let's pass the same inputs to the built-in AdditiveAttention layer and see what we get:

buit_attn = tf.keras.layers.AdditiveAttention()

y = tf.random.uniform((2, 60, 512))  
out, attn = buit_attn([y, y], return_attention_scores=True)
out.shape , attn.shape
(TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))

y = tf.random.uniform((2, 512))  
out, attn = buit_attn([y, y], return_attention_scores=True)
out.shape , attn.shape
(TensorShape([2, 512]), TensorShape([2, 2]))

So the shape of the context_vector is comparable here, but not the shape of attention_weights. The reason, as mentioned above, is that the tutorial's implementation is somewhat modified and adapted. If we look at how BahdanauAttention (i.e. AdditiveAttention) is computed, we get:

  1. Reshape query and value into shapes [batch_size, Tq, 1, dim] and [batch_size, 1, Tv, dim] respectively.
  2. Calculate scores with shape [batch_size, Tq, Tv] as a non-linear sum: scores = tf.reduce_sum(tf.tanh(query + value), axis=-1)
  3. Use scores to calculate a distribution with shape [batch_size, Tq, Tv]: distribution = tf.nn.softmax(scores).
  4. Use distribution to create a linear combination of values with shape [batch_size, Tq, dim]: return tf.matmul(distribution, value) (these steps are traced with raw ops right after this list).
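To make the shapes in steps 1 to 4 concrete, here is a raw-op trace with toy shapes (no learned weights or scaling, purely to follow the shapes):

query = tf.random.uniform((2, 60, 512))   # [batch_size, Tq, dim]
value = tf.random.uniform((2, 60, 512))   # [batch_size, Tv, dim]

q = tf.expand_dims(query, 2)                         # step 1: [batch_size, Tq, 1, dim]
v = tf.expand_dims(value, 1)                         #         [batch_size, 1, Tv, dim]
scores = tf.reduce_sum(tf.tanh(q + v), axis=-1)      # step 2: [batch_size, Tq, Tv]
distribution = tf.nn.softmax(scores)                 # step 3: [batch_size, Tq, Tv]
context = tf.matmul(distribution, value)             # step 4: [batch_size, Tq, dim]
context.shape, distribution.shape
(TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))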

The implementation in that tutorial computes the attention weights a bit differently. If we follow the approach above (steps 1 to 4), we get the same output shape for attention_weights as well. Here is how (note that this is only for demonstration purposes and not a general implementation):

class BahdanauAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    # W1, W2 and V are kept from the original class for comparison,
    # but they are not used in the simplified call() below
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values):
    query_with_time_axis = tf.expand_dims(query, 2)  # [batch_size, Tq, 1, dim]
    value_with_time_axis = tf.expand_dims(values, 1) # [batch_size, 1, Tv, dim]
    scores = tf.reduce_sum(tf.tanh(query_with_time_axis +
                                   value_with_time_axis), axis=-1)  # [batch_size, Tq, Tv]
    distribution = tf.nn.softmax(scores)                   # [batch_size, Tq, Tv]
    return tf.matmul(distribution, values), distribution   # context: [batch_size, Tq, dim]

Now, if we pass the same input, we get the same output shapes from both implementations. For general use cases, though, the built-in implementation should be preferred.

attention_layer = BahdanauAttention(10)

y = tf.random.uniform((2, 60, 512))  
out, attn = attention_layer(y, y)
out.shape , attn.shape
(TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))

buit_attn = tf.keras.layers.AdditiveAttention()
y = tf.random.uniform((2, 60, 512))  
out, attn = buit_attn([y, y], return_attention_scores=True)
out.shape , attn.shape
(TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))
Innat
  • Thanks for the answer. But then how can the built-in attention layer be used for problems that require inputs of shape `(batch_size, units)`, say for text classification? I tried passing the output of the built-in attention layer through `Flatten` before passing it to the `Dense` layer, but it throws an error. – data_person May 02 '21 at 10:02
  • You mean something like this: `y = tf.random.uniform((2, 512)); out, attn = buit_attn([y, y], return_attention_scores=True)`? I'm not quite sure what you're aiming for, but this should work. – Innat May 02 '21 at 11:00
  • Sorry if I was not clear. I am confused about using the output of the built-in attention layer for text classification, since it is a 3D output. When I try to pass the output of the attention layer through `Flatten` so I can feed it to a `Dense` layer, it throws an error. – data_person May 02 '21 at 12:01
  • [This](https://keras.io/api/layers/attention_layers/additive_attention/) shows how to use this built-in layer. Otherwise, if you're simply following that tutorial, I think you should adopt that implementation. I wish I could help more. – Innat May 02 '21 at 12:31
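As a follow-up to the comment thread above: the sketch below is only an illustration, with made-up shapes, and GlobalAveragePooling1D is just one possible choice (it is not mentioned in the thread). One way to feed the 3D output of the built-in layer into a Dense classifier is to reduce the time axis first instead of using Flatten:

inputs  = tf.keras.Input(shape=(60, 512))                        # (Tq, dim); illustrative shapes
attn    = tf.keras.layers.AdditiveAttention()([inputs, inputs])  # (batch_size, Tq, dim), self-attention
pooled  = tf.keras.layers.GlobalAveragePooling1D()(attn)         # (batch_size, dim)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(pooled)
model   = tf.keras.Model(inputs, outputs)

Pooling (average or max) over the time axis yields a fixed-size vector regardless of sequence length, which is usually what a downstream Dense classifier needs.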