
I'd like to implement an encoder-decoder architecture based on an LSTM or GRU with an attention layer. I saw that Keras has a layer for that, tf.keras.layers.Attention, and I'd like to use it (all other questions and resources seem to implement attention themselves or use third-party libraries). Also, I'm not using the network for sequence-to-sequence translation but for binary classification, so the example provided in the documentation is a bit confusing to me.

I'm imagining a model like this.

import tensorflow as tf

x = tf.keras.Input((100, 50))

# encoder
hidden_states = tf.keras.layers.GRU(32, return_sequences=True)(x)

# decoder + attention
? = tf.keras.layers.Attention()([?, ?])
z = tf.keras.layers.GRU(32)(?)

# classification
z = tf.keras.layers.Dense(1, activation='sigmoid')(z)

model = tf.keras.Model(inputs=x, outputs=z)

The decoder and attention parts of this network are unclear to me. I know that I need to create a context vector from the hidden states of the encoder and the decoder's current hidden state.
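To make that idea concrete, here is a rough sketch of a dot-product (Luong-style) context vector with made-up shapes; this is just the concept, not the Keras layer itself:

import tensorflow as tf

encoder_states = tf.random.normal((1, 100, 32))  # [batch, timesteps, units], dummy values
decoder_state = tf.random.normal((1, 1, 32))     # current decoder hidden state, dummy values

scores = tf.matmul(decoder_state, encoder_states, transpose_b=True)  # [1, 1, 100]
weights = tf.nn.softmax(scores, axis=-1)                             # attention weights
context = tf.matmul(weights, encoder_states)                         # [1, 1, 32] context vector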

How would I implement the decoder and attention part of this network?

TomTom

1 Answer


It may be a late answer for you, but for future reference, here is some starter code along the lines of the model you imagined. Currently, there are three built-in attention layers in tf.keras (a minimal sketch of how each is called follows the list), namely

- MultiHeadAttention layer
- Attention layer (a.k.a. Luong-style attention)
- AdditiveAttention layer (a.k.a. Bahdanau-style attention)
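As a quick reference, here is roughly how each layer is called on dummy query/value tensors (shapes are illustrative and not tied to the model below):

import tensorflow as tf
from tensorflow.keras.layers import Attention, AdditiveAttention, MultiHeadAttention

# dummy query/value tensors of shape [batch, timesteps, features]
query = tf.random.normal((2, 10, 16))
value = tf.random.normal((2, 12, 16))

# Luong-style (dot-product) attention: inputs passed as a list
luong_out = Attention()([query, value])              # [2, 10, 16]

# Bahdanau-style (additive) attention: same calling convention
bahdanau_out = AdditiveAttention()([query, value])   # [2, 10, 16]

# Multi-head attention: query and value passed as separate arguments
mha_out = MultiHeadAttention(num_heads=2, key_dim=16)(query, value)  # [2, 10, 16]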

For the starter code, we'll use Luong-style attention in the encoder and the Bahdanau-style attention mechanism in the decoder. The overall architecture would be

a. encoder: input -> embedding -> gru -> luong-style-attn
b. decoder: encoder output -> lstm -> bahdanau-style-attn -> gap -> classifier

# whole model
autoencoder: encoder + decoder

Let's build the model accordingly.


Encoder

import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import *
from tensorflow.keras import backend
from tensorflow.keras import utils
backend.clear_session()

# int sequences.
enc_inputs = Input(shape=(20,), name='enc_inputs')

# Embedding lookup and GRU
embedding = Embedding(input_dim=100, output_dim=64)(enc_inputs)
whole_sequence = GRU(4, return_sequences=True)(embedding)

# Self-attention over the GRU states (query = value); output shape [batch_size, Tq, units].
query_value_attention_seq = Attention()([whole_sequence, whole_sequence])

# build encoder model 
encoder = Model(enc_inputs, query_value_attention_seq, name='encoder')

Checking the layout:

utils.plot_model(encoder, show_shapes=True)

[plot: encoder model layout]
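A quick sanity check of the encoder output shape on an arbitrary dummy batch (the values themselves don't matter):

dummy_batch = np.random.randint(0, 100, (2, 20))
print(encoder(dummy_batch).shape)  # (2, 20, 4)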

Decoder

# encoder output sequences, shape (timesteps, units) = (20, 4)
dec_input = Input(shape=(20, 4), name='dec_inputs')

# LSTM
whole_sequence = LSTM(4, return_sequences=True)(dec_input)

# Attention over the encoder output: query = LSTM states, value = dec_input,
# giving shape [batch_size, Tq, units].
query_value_attention_seq = AdditiveAttention()([whole_sequence, dec_input])

# Reduce over the sequence axis to produce encodings of shape
# [batch_size, filters].
query_value_attention = GlobalAveragePooling1D()(query_value_attention_seq)

# classification
dec_output = Dense(1, activation='sigmoid')(query_value_attention)

# build decoder model
decoder = Model(dec_input, dec_output, name='decoder')

Checking the layout:

utils.plot_model(decoder, show_shapes=True)

[plot: decoder model layout]
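And a similar shape check for the decoder, feeding it a dummy tensor shaped like the encoder output:

dummy_enc_out = np.random.normal(size=(2, 20, 4)).astype('float32')
print(decoder(dummy_enc_out).shape)  # (2, 1)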

Autoencoder

# encoder 
encoder_init = Input(shape=(20, ))
encoder_output = encoder(encoder_init); print(encoder_output.shape)

# decoder 
decoder_output = decoder(encoder_output); print(decoder_output.shape)

# bind all: autoencoder 
autoencoder = Model(encoder_init, decoder_output)

# check layout 
utils.plot_model(autoencoder, show_shapes=True, expand_nested=True)

[plot: full autoencoder layout]

Dummy Training

x_train = np.random.randint(0, 10, (100,20)); print(x_train.shape)
y_train = np.random.randint(2, size=(100, 1)); print(y_train.shape)
(100, 20)
(100, 1)

autoencoder.compile('adam', 'binary_crossentropy')
autoencoder.fit(x_train, y_train, epochs=5, verbose=2)

Epoch 1/5
4/4 - 4s - loss: 0.6674
Epoch 2/5
4/4 - 0s - loss: 0.6637
Epoch 3/5
4/4 - 0s - loss: 0.6600
Epoch 4/5
4/4 - 0s - loss: 0.6590
Epoch 5/5
4/4 - 0s - loss: 0.6571
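Since the end goal is binary classification, a minimal inference sketch might look like this (reusing the dummy x_train; the actual labels you get will vary):

probs = autoencoder.predict(x_train[:5])   # sigmoid outputs, shape (5, 1)
labels = (probs > 0.5).astype('int32')     # hard 0/1 labels
print(probs.shape, labels.ravel())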

Resources

Also, you can check out my other answer regarding the attention mechanism.

And this one is my favorite on multi-head attention in transformers; it's a three-part video series.

Innat