
I am implementing self-attention in TensorFlow Keras with a slight modification (e.g., a residual (add) connection).

I have the following input shape:

myinput: KerasTensor(type_spec=TensorSpec(shape=(None, 8, 6, 64), dtype=tf.float32, name=None), name='multiply/mul:0', description="created by layer 'multiply'")

My goal is to process the tensor of shape (None, 8, 6, 64) by passing each of the 8 time steps (each a (6, 64) slice) through self-attention one at a time, obtaining a self-attention feature map for every time step, and then concatenating them back into an output tensor of shape (None, 8, 6, 64).

Implemented Code:

import tensorflow as tf
from tensorflow.keras.layers import Permute


def conv1d(channels, ks=1, strides=1, padding='same'):
    conv = tf.keras.layers.Conv1D(channels, ks, strides, padding, activation='relu', use_bias=False,
                                  kernel_initializer='HeNormal')
    return conv


class my_self_attention(tf.keras.layers.Layer):
    def __init__(self, channels):
        super(my_self_attention, self).__init__()
        self.query = conv1d(channels)
        self.key = conv1d(channels)
        self.value = conv1d(channels)
        self.gamma = tf.compat.v1.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

    def call(self, x):
        x = tf.reshape(x, shape=[-1, x.shape[2], x.shape[3]])
        f = self.query(x),
        g = self.key(x)
        h = self.value(x)
        attention_weights = tf.keras.activations.softmax(
            tf.matmul(g, Permute((2, 1))(f)))  # query multiply with key and then softmax on it
        sensor_att_fm = tf.matmul(attention_weights, h)
        o = self.gamma * sensor_att_fm + x
        # return tf.reshape(o, shape = [-1, 1, x.shape[1], x.shape[2]])
        return tf.reshape(o, shape=[-1, 1, x.shape[1], x.shape[2]])


sa = my_self_attention(channels)
refined_fm = tf.concat([sa(tf.expand_dims(my_input[:, t, :, :], 1)) for t in range(my_input.shape[1])], 1)

I am getting the following error:

ValueError: Dimension must be 4 but is 3 for '{{node my_self_attention/permute/transpose}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](my_self_attention/permute/transpose/a, my_self_attention/permute/transpose/perm)' with input shapes: [1,?,6,64], [3].

How should I fix this issue?

Ahmad

1 Answer


Make sure the number of channels is the same as the last dimension of your input data. Also, the trailing comma here: `f = self.query(x),` creates a tuple, which you probably do not want. Here is a working example:

import tensorflow as tf
from tensorflow.keras.layers import Permute

def conv1d(channels, ks=1, strides=1, padding='same'):
  conv = tf.keras.layers.Conv1D(channels, ks, strides, padding, activation='relu', use_bias=False,
                                kernel_initializer='HeNormal')
  return conv

class my_self_attention(tf.keras.layers.Layer):
  def __init__(self, channels):
      super(my_self_attention, self).__init__()
      self.query = conv1d(channels)
      self.key = conv1d(channels)
      self.value = conv1d(channels)
      self.gamma = tf.compat.v1.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

  def call(self, x):
      x = tf.reshape(x, shape=[-1, x.shape[2], x.shape[3]])
      f = self.query(x)
      g = self.key(x)
      h = self.value(x)

      attention_weights = tf.keras.activations.softmax(
          tf.matmul(g, Permute((2, 1))(f)))  # query multiply with key and then softmax on it
      sensor_att_fm = tf.matmul(attention_weights, h)
      o = self.gamma * sensor_att_fm + x
      return tf.reshape(o, shape=[-1, 1, x.shape[1], x.shape[2]])


my_input = tf.random.normal((50, 8, 6, 64))
sa = my_self_attention(64)
refined_fm = tf.concat([sa(tf.expand_dims(my_input[:, t, :, :], 1)) for t in range(my_input.shape[1])], 1)
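To see concretely why that trailing comma matters, here is a minimal, standalone illustration (the `tf.ones` input and the `tf.identity` calls are placeholders chosen just for this demonstration, not part of the layer above):

import tensorflow as tf

x = tf.ones((2, 6, 64))

f = tf.identity(x),  # trailing comma: f is a Python tuple wrapping the tensor
g = tf.identity(x)   # no comma: g is a plain tensor

print(type(f))  # <class 'tuple'>
print(type(g))  # e.g. <class 'tensorflow.python.framework.ops.EagerTensor'>

Passing such a tuple on to `Permute` is what leads to the 4-D vs. 3-D mismatch reported in the error above. If the corrected example runs end to end, `refined_fm` should come out with shape `(50, 8, 6, 64)`, i.e. the per-time-step attention outputs concatenated back into the original layout. As a side note, a more TF2-idiomatic way to create the learnable scalar would be `self.add_weight(name='gamma', shape=(1,), initializer='zeros')` inside the layer, rather than `tf.compat.v1.get_variable`.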
AloneTogether
  • Oh, what a silly mistake I made. The number of channels was already the same as the last dimension of the input, but I had made a typo with the comma: `f = self.query(x),`. I just removed the comma and the problem is fixed. Thanks a lot. – Ahmad Jan 31 '22 at 10:30
  • However, now I am getting the warning `WARNING:absl:Found untraced functions such as conv1d_layer_call_and_return_conditional_losses, conv1d_layer_call_fn, conv1d_1_layer_call_and_return_conditional_losses, conv1d_1_layer_call_fn, conv1d_2_layer_call_and_return_conditional_losses while saving (showing 5 of 15).` – Ahmad Jan 31 '22 at 10:30
  • 1
  • Hmm, this seems to be a new problem, because with the code you provided I don't get this error. Maybe post a new question with more information? – AloneTogether Jan 31 '22 at 10:32
  • Here it is: https://stackoverflow.com/questions/70924722/warningabslfound-untraced-functions-such-as-conv1d-layer-call-and-return-condi – Ahmad Jan 31 '22 at 10:51