I am new to Keras and want to build a model with the structure input >> attention >> LSTM >> attention >> output.
The model can be created (model.summary() works), but model.fit fails with a "required broadcastable shapes" error, and I don't understand why a model that builds fine cannot be trained.
model.fit runs successfully when there is only one attention layer, i.e. input >> LSTM >> attention >> output, so the problem seems to be related to the custom attention layer class. Has anyone seen something similar, or is there another way to implement two attention layers?
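For reference, this is roughly what my working single-attention version looks like (a minimal sketch; it reuses the attention_o class, hidden_units, time_steps and the training data from the full code further below):

# Single-attention model: input >> LSTM >> attention >> output (this trains fine)
x = Input(shape=(time_steps, 5))
lstm_out = tf.keras.layers.LSTM(hidden_units,
                                recurrent_regularizer=tf.keras.regularizers.l1_l2(l1=0, l2=0.01),
                                return_sequences=True,
                                activation='tanh')(x)
att = attention_o()(lstm_out)
outputs = Dense(3, activation='tanh')(att)
model_single = Model(x, outputs)
model_single.compile(optimizer='adam',
                     loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))
model_single.fit(Feature_train, Target_train, epochs=epochs, batch_size=128, verbose=1)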
The two-attention model is created successfully; the output of model_attention.summary() is:
Model: "model_12"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_14 (InputLayer) [(None, 150, 5)] 0
_________________________________________________________________
attention_i_10 (attention_i) (None, 150, 5) 155
_________________________________________________________________
lstm_13 (LSTM) (None, 150, 32) 4864
_________________________________________________________________
attention_o_12 (attention_o) (None, 32) 182
_________________________________________________________________
dense_12 (Dense) (None, 3) 99
=================================================================
Total params: 5,300
Trainable params: 5,300
Non-trainable params: 0
_________________________________________________________________
Train on 1347850 samples
Epoch 1/5
But model_attention.fit fails. The whole code:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, Input, Dense
from tensorflow.keras.models import Model

class attention_i(Layer):
    def __init__(self, **kwargs):
        super(attention_i, self).__init__(**kwargs)

    def build(self, input_shape):
        self.W = self.add_weight(name='attention_weight', shape=(input_shape[-1], 1),
                                 initializer='random_normal', trainable=True)
        self.b = self.add_weight(name='attention_bias', shape=(input_shape[1], 1),
                                 initializer='zeros', trainable=True)
        super(attention_i, self).build(input_shape)

    def call(self, x):
        # Alignment scores, passed through tanh
        e = K.tanh(K.dot(x, self.W) + self.b)
        # Remove the dimension of size 1 (commented out in this layer)
        # e = K.squeeze(e, axis=-1)
        # Compute the attention weights
        alpha = K.softmax(e)
        # Add a trailing axis so the weights can be multiplied with x
        alpha = K.expand_dims(alpha, axis=-1)
        # Compute the context vector
        context = x * alpha
        context = K.sum(context, axis=1)
        return context
class attention_o(Layer):
    def __init__(self, **kwargs):
        super(attention_o, self).__init__(**kwargs)

    def build(self, input_shape):
        self.W = self.add_weight(name='attention_weight', shape=(input_shape[-1], 1),
                                 initializer='random_normal', trainable=True)
        self.b = self.add_weight(name='attention_bias', shape=(input_shape[1], 1),
                                 initializer='zeros', trainable=True)
        super(attention_o, self).build(input_shape)

    def call(self, x):
        # Alignment scores, passed through tanh
        e = K.tanh(K.dot(x, self.W) + self.b)
        # Remove the dimension of size 1
        e = K.squeeze(e, axis=-1)
        # Compute the attention weights
        alpha = K.softmax(e)
        # Add a trailing axis so the weights can be multiplied with x
        alpha = K.expand_dims(alpha, axis=-1)
        # Compute the context vector
        context = x * alpha
        context = K.sum(context, axis=1)
        return context
def create_RNN_with_attention(hidden_units, dense_units, input_shape, activation):
    x = Input(shape=input_shape)
    attention_layer_input = attention_i()(x)
    LSTM_layer = tf.keras.layers.LSTM(
        hidden_units,
        recurrent_regularizer=tf.keras.regularizers.l1_l2(l1=0, l2=0.01),
        return_sequences=True,
        activation='tanh')(attention_layer_input)
    attention_layer_output = attention_o()(LSTM_layer)
    outputs = Dense(dense_units, trainable=True, activation=activation)(attention_layer_output)
    model = Model(x, outputs)
    model.compile(loss='mse', optimizer='adam')
    return model

model_attention = create_RNN_with_attention(hidden_units=hidden_units, dense_units=3,
                                            input_shape=(time_steps, 5), activation='tanh')
model_attention.compile(optimizer='adam',
                        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))

with tf.device('/device:GPU:0'):
    model_attention.summary()
    model_attention.fit(Feature_train, Target_train, epochs=epochs, batch_size=128, verbose=1)
The following error appears:
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: required broadcastable shapes
[[{{node attention_i_10/mul}}]]
[[loss_14/AddN/_1303]]
(1) Invalid argument: required broadcastable shapes
[[{{node attention_i_10/mul}}]]
0 successful operations.
0 derived errors ignored.
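In case it helps to reproduce this outside of model.fit, here is a minimal sketch I would expect to hit the same multiplication node (it assumes TF 2.x eager execution, uses the attention_i class defined above, and feeds random dummy data shaped like one of my training batches, i.e. batch_size=128, time_steps=150, 5 features):

# Reproduction sketch (assumes TF 2.x eager execution; random dummy data, not my real dataset)
import numpy as np
import tensorflow as tf

# Dummy batch shaped like the real data: (batch_size=128, time_steps=150, features=5)
dummy_batch = tf.constant(np.random.rand(128, 150, 5), dtype=tf.float32)

layer = attention_i()
context = layer(dummy_batch)  # the x * alpha multiplication here corresponds to the attention_i_10/mul node in the error
print(context.shape)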