
I'm trying to run an RNN trained in Keras in an application that runs in real time. The "time" steps of the recurrent network (it's an LSTM) correspond to actual moments when data is received.

I want to get the output of the RNN in an online fashion. For non-recurrent models, I simply reshaped my input to (1, input_shape) and ran Model.predict on it. I'm not sure that's the intended way to do a forward pass in Keras for this kind of application, but it worked for me.

But for recurrent models, Model.predict expects the whole input, including the temporal dimension, so this approach does not work...

Is there a way to do this in Keras or do I need to go down to Tensorflow and implement the operation there?

psacawa

1 Answer


You can set the LSTM layer to be stateful. The internal state of the LSTM will then be kept across predict calls until you call model.reset_states() manually.

For example, suppose we have trained a simple LSTM model.

import numpy as np
from keras.layers import Input, LSTM, Dense
from keras.models import Model

x = Input(shape=(None, 10))
h = LSTM(8)(x)
out = Dense(4)(h)
model = Model(x, out)
model.compile(loss='mse', optimizer='adam')

X_train = np.random.rand(100, 5, 10)
y_train = np.random.rand(100, 4)
model.fit(X_train, y_train)

Then, the weights can be loaded into another model with stateful=True for prediction (remember to set batch_shape in the Input layer).

x = Input(batch_shape=(1, None, 10))
h = LSTM(8, stateful=True)(x)
out = Dense(4)(h)
predict_model = Model(x, out)

# copy the weights from `model` to this model
predict_model.set_weights(model.get_weights())

For your use case: since predict_model is stateful, consecutive predict calls on length-1 subsequences give the same result as predicting on the entire sequence. Just remember to call reset_states() before starting a new sequence.

X = np.random.rand(1, 3, 10)
print(model.predict(X))
# [[-0.09485822,  0.03324107,  0.243945  , -0.20729265]]

predict_model.reset_states()
for t in range(3):
    print(predict_model.predict(X[:, t:(t + 1), :]))
# [[-0.04117237 -0.06340873  0.10212967 -0.06400848]]
# [[-0.12808001  0.0039286   0.23223262 -0.23842749]]
# [[-0.09485822  0.03324107  0.243945   -0.20729265]]
Yu-Yang