
I am trying to understand how to use the tf.keras.layers.Attention shown here:

Tensorflow Attention Layer

I am trying to use it with an encoder-decoder seq2seq model. Below is my code:

from tensorflow.keras.layers import Input, Embedding, LSTM

# Encoder
encoder_inputs = Input(shape=(max_len_text,))
enc_emb = Embedding(x_voc_size, latent_dim, trainable=True)(encoder_inputs)
encoder_lstm = LSTM(latent_dim, return_state=True, return_sequences=True)
encoder_outputs, state_h, state_c = encoder_lstm(enc_emb)

# Decoder, initialised with the encoder's final states
decoder_inputs = Input(shape=(max_len_summary,))
dec_emb_layer = Embedding(y_voc_size, latent_dim, trainable=True)
dec_emb = dec_emb_layer(decoder_inputs)

decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, decoder_fwd_state, decoder_back_state = decoder_lstm(dec_emb, initial_state=[state_h, state_c])

My question is: how do I use the given Attention layer in Keras with this model? I am not able to understand their documentation.

1 Answer


If you are using an RNN, I would not recommend using the above class.

While analysing the tf.keras.layers.Attention GitHub code to better understand your conundrum, the first line I came across was: "This class is suitable for Dense or CNN networks, and not for RNN networks".
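
That said, if you want to see how the built-in layer would mechanically slot into your model, the wiring would look roughly like the sketch below (dot-product attention with the decoder states as the query and the encoder states as the value; the Concatenate/TimeDistributed/Dense finish is just one illustrative way to complete the model):

from tensorflow.keras.layers import Attention, Concatenate, TimeDistributed, Dense
from tensorflow.keras.models import Model

# Dot-product (Luong-style) attention: query = decoder states, value = encoder states
attn_layer = Attention()
context = attn_layer([decoder_outputs, encoder_outputs])  # (batch, max_len_summary, latent_dim)

# Combine the decoder states with the attention context before the output projection
decoder_concat = Concatenate(axis=-1)([decoder_outputs, context])
decoder_dense = TimeDistributed(Dense(y_voc_size, activation='softmax'))
outputs = decoder_dense(decoder_concat)

model = Model([encoder_inputs, decoder_inputs], outputs)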

I would recommend writing your own seq2seq model, which can be done in less than a dozen lines of code. For example: https://www.tensorflow.org/tutorials/text/nmt_with_attention

To write your own custom attention layer (depending on whether you prefer Bahdanau, Luong, Raffel, Yang, etc.), perhaps this post outlining the basic essence may help: Custom Attention Layer using in Keras
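
To give a flavour of that basic essence, a minimal Bahdanau-style (additive) attention layer along the lines of the linked tutorial could look like the sketch below (the layer and variable names are illustrative):

import tensorflow as tf

class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)  # projects the decoder query
        self.W2 = tf.keras.layers.Dense(units)  # projects the encoder outputs
        self.V = tf.keras.layers.Dense(1)       # scores each encoder timestep

    def call(self, query, values):
        # query: decoder hidden state, shape (batch, latent_dim)
        # values: encoder outputs, shape (batch, max_len_text, latent_dim)
        query_with_time_axis = tf.expand_dims(query, 1)
        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
        attention_weights = tf.nn.softmax(score, axis=1)  # softmax over encoder timesteps
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights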

Allohvk
    Ah this makes sense. I could never get this layer to work with LSTM networks. I think you need to write custom training loops in that case with a custom attention layer. Basically, just as the tutorial says, you need to iterate over the decoder one at a time, using the encoder sequence, especially if you want teacher forcing, which generally you do. It doesn't seem like you can cheat on this and just feed in the full decoded sequence, but I think that makes sense because the state needs to be passed after every prediction, which includes the previous context vector. – neuroguy123 Dec 01 '20 at 21:31
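
To make that concrete, a rough sketch of the step-by-step, teacher-forced decoding loop the comment describes might look like this (decoder_lstm_cell, attention and output_dense are illustrative names, e.g. an LSTMCell(latent_dim), the BahdanauAttention sketch above and a Dense(y_voc_size, activation='softmax') projection; they are not part of the original code):

import tensorflow as tf

def decode_with_teacher_forcing(encoder_outputs, state_h, state_c, target_seq,
                                dec_emb_layer, decoder_lstm_cell, attention, output_dense):
    # target_seq: (batch, max_len_summary) ground-truth summary tokens
    states = [state_h, state_c]
    predictions = []
    for t in range(target_seq.shape[1] - 1):
        # Teacher forcing: feed the ground-truth token at step t
        dec_input = dec_emb_layer(target_seq[:, t])             # (batch, latent_dim)
        context, _ = attention(states[0], encoder_outputs)      # attend over encoder states
        step_input = tf.concat([dec_input, context], axis=-1)   # (batch, 2 * latent_dim)
        output, states = decoder_lstm_cell(step_input, states)  # one decoder step
        predictions.append(output_dense(output))                # probabilities for step t + 1
    return tf.stack(predictions, axis=1)                        # (batch, max_len_summary - 1, y_voc_size)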