
I am trying to understand recurrent neural nets better using a simple example where the training set consists of n ones followed by a minus one, repeated (e.g., train_set = [*([1]*n), -1]*10_000). I would like to find an architecture that can converge to zero error on this training set for different values of n. In particular, can this be accomplished with a stateful RNN and window_size=1?
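For reference, the target function is deterministic and in principle needs only a single counter as state. A hand-written recurrence (my own sketch; the helper name is made up) that a stateful RNN with window_size=1 would have to approximate:

def reference_step(state, x, n=13):
    # The state counts consecutive ones; it resets after a -1.
    state = state + 1 if x == 1 else 0
    # The next value is -1 exactly once n ones have been seen.
    return state, (-1 if state == n else 1)

state = 0
preds = []
for x in [*([1] * 13), -1] * 3:
    state, y = reference_step(state, x)
    preds.append(y)
# preds matches the next-step targets, so one hidden counter suffices.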

Below is an example that does not converge:

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, SimpleRNN
from tensorflow.keras.optimizers import Adam
from matplotlib import pyplot as plt


window = 1
train_set = [*([1] * 13), -1] * (10_000 + window)  # n = 13 ones followed by a -1
batch_size = 1000

# Build (input, next-value) pairs with a sliding window of length `window`.
feature = []
target = []
for i in range(len(train_set) - window):
    feature.append(train_set[i:i+window])
    target.append(train_set[i+window])

# Shape (samples, timesteps, features) to match the RNN's batch_input_shape.
feature = np.array(feature).reshape(len(feature), window, 1)
target = np.array(target)

# A stateful RNN needs every batch to be full, so trim to a multiple of batch_size.
len_data = len(feature)//batch_size*batch_size
feature = feature[:len_data]
target = target[:len_data]


hidden_size = 20

optimizer = Adam()

model = Sequential()
# stateful=True carries the hidden state across batches, so information can
# flow across the one-step windows.
model.add(SimpleRNN(hidden_size, batch_input_shape=(batch_size, window, 1), stateful=True))
model.add(Dense(1, activation='linear'))
model.compile(loss='mse', optimizer=optimizer)

epochs = 25
for i in range(epochs):
    # Reset the carried hidden state once per pass over the data.
    model.reset_states()
    model.fit(feature, target,
              epochs=1, batch_size=batch_size, verbose=1, shuffle=False)

predictions = model.predict(feature, batch_size=batch_size)
predictions = np.squeeze(predictions)
plt.figure(figsize=(14, 5))
plt.plot(target[-100:], marker='o', label='target')
plt.plot(predictions[-100:], marker='+', label='pred')
plt.legend()
plt.show()

Training output: the loss plateaus almost immediately:

140/140 [==============================] - 1s 2ms/step - loss: 0.3727
140/140 [==============================] - 0s 2ms/step - loss: 0.2595
140/140 [==============================] - 0s 2ms/step - loss: 0.2531
140/140 [==============================] - 0s 2ms/step - loss: 0.2549
140/140 [==============================] - 0s 2ms/step - loss: 0.2527
140/140 [==============================] - 0s 2ms/step - loss: 0.2526
140/140 [==============================] - 0s 2ms/step - loss: 0.2518
140/140 [==============================] - 0s 2ms/step - loss: 0.2512
140/140 [==============================] - 0s 2ms/step - loss: 0.2506
140/140 [==============================] - 0s 2ms/step - loss: 0.2502
140/140 [==============================] - 0s 2ms/step - loss: 0.2500
140/140 [==============================] - 0s 2ms/step - loss: 0.2500
140/140 [==============================] - 0s 2ms/step - loss: 0.2501
140/140 [==============================] - 0s 2ms/step - loss: 0.2502
140/140 [==============================] - 0s 2ms/step - loss: 0.2502
140/140 [==============================] - 0s 2ms/step - loss: 0.2503
140/140 [==============================] - 0s 2ms/step - loss: 0.2503
140/140 [==============================] - 0s 2ms/step - loss: 0.2502
140/140 [==============================] - 0s 2ms/step - loss: 0.2504
140/140 [==============================] - 0s 2ms/step - loss: 0.2511
140/140 [==============================] - 0s 3ms/step - loss: 0.2514
140/140 [==============================] - 0s 2ms/step - loss: 0.2476
140/140 [==============================] - 0s 2ms/step - loss: 0.2465
140/140 [==============================] - 0s 2ms/step - loss: 0.2481
140/140 [==============================] - 0s 2ms/step - loss: 0.2506

Predictions are in orange: not much has been learned.
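As a back-of-the-envelope check (my own, not from the run above): the plateau near 0.25 is only slightly below the loss of the best constant prediction, i.e. the variance of the targets, which is consistent with the flat orange curve:

import numpy as np

target = np.array([*([1] * 13), -1] * 10_000)
baseline = np.mean((target - target.mean()) ** 2)  # MSE of predicting the mean
print(baseline)  # ~0.2653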
