I am trying to build a seq2seq autoencoder that should capture the logic of a sequence and be able to reconstruct it from the state vectors. To test whether the model can handle a very simple version of the task, I use example sequences like '<><><>...' or '(.)(.)(.)...'.
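For reference, they are generated along these lines (sketch; the pattern list and repeat count are arbitrary):

patterns = ['<>', '(.)']
sequences = [p * 20 for p in patterns]  # '<><><><>...' and '(.)(.)(.)...'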
The basic code of the model looks like this:
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from tensorflow.keras.models import Model

n_hidden = 256
emb_dim = 16
n_tokens = 559
#Encoder
enc_inp = Input(shape=(None,))
emb_layer = Embedding(input_dim=n_tokens, output_dim=emb_dim)  # shared with the decoder
enc_emb = emb_layer(enc_inp)
enc_layer = LSTM(n_hidden, return_state=True)
enc, state_h, state_c = enc_layer(enc_emb)  # only the final states are passed to the decoder
#Decoder (teacher forcing: the decoder sees the shifted target sequence during training)
dec_inp = Input(shape=(None,))
dec_emb = emb_layer(dec_inp)
dec_layer = LSTM(n_hidden, return_sequences=True, return_state=True)
dec, _, _ = dec_layer(dec_emb, initial_state=[state_h, state_c])
dense_layer = Dense(n_tokens, activation='softmax')
res = dense_layer(dec)
model = Model([enc_inp, dec_inp], res)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
#Models for Inference
#Encoder
encoder_model = Model(enc_inp, [state_h, state_c])
#Decoder (reuses the trained embedding, LSTM and Dense layers)
state_inp1 = Input(shape=(n_hidden,))
state_inp2 = Input(shape=(n_hidden,))
dec, state_h, state_c = dec_layer(dec_emb, initial_state=[state_inp1, state_inp2])
res = dense_layer(dec)
decoder_model = Model([dec_inp, state_inp1, state_inp2], [res, state_h, state_c])
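Prediction is then done token by token with these two models. Roughly (sketch; decode_sequence, start_id and max_len are placeholder names for my start-token id and a length cutoff):

import numpy as np

def decode_sequence(seq_ids, start_id, max_len=60):
    # seq_ids: list of integer-encoded characters
    h, c = encoder_model.predict(np.array([seq_ids]))  # encode into initial states
    cur = np.array([[start_id]])
    decoded = []
    for _ in range(max_len):
        probs, h, c = decoder_model.predict([cur, h, c])
        next_id = int(np.argmax(probs[0, -1]))  # greedy: most probable next token
        decoded.append(next_id)
        cur = np.array([[next_id]])  # feed the prediction back as the next input
    return decoded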
The encoder inputs are the integer-encoded characters, which are turned into vectors by the embedding layer. The decoder inputs are the same as the encoder inputs, but with a start token prepended and therefore shifted one position to the right. The targets used in training are the one-hot encoded encoder inputs.
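Concretely, for a single sequence the three arrays are built roughly like this (sketch; make_training_arrays and start_id are placeholder names, and I drop the last token of the decoder input so the lengths match):

import numpy as np
from tensorflow.keras.utils import to_categorical

def make_training_arrays(seq_ids, start_id):
    enc_in = np.array([seq_ids])                    # integer-encoded characters
    dec_in = np.array([[start_id] + seq_ids[:-1]])  # shifted one to the right
    targets = to_categorical(enc_in, num_classes=n_tokens)  # one-hot encoder inputs
    return enc_in, dec_in, targets

Training is then model.fit([enc_in, dec_in], targets, ...).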
Now the model is not performing well; it only predicts the first character over and over again, so for:
Original: '<><><>...', Prediction: '<<<<<<...'
Original: '(.)(.)(.)...', Prediction: '((((((...'
Is it just a question of training, or is there some crucial mistake I am making here?