...coming from TensorFlow, where pretty much every shape is defined explicitly, I am confused by Keras' API for recurrent models. Getting an Elman network to work in TF was pretty easy, but Keras refuses to accept what I believe are the correct shapes...
For example:
import keras as k  # assumed alias for the k.layers / k.models calls below
x = k.layers.Input(shape=(2,))
y = k.layers.Dense(10)(x)
m = k.models.Model(x, y)
...works perfectly, and according to model.summary() I get an input layer with shape (None, 2), followed by a dense layer with output shape (None, 10). That makes sense, since Keras automatically adds the first dimension for batch processing.
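For comparison with the recurrent case, a Dense layer only transforms the last axis, so the batch axis that Keras adds (the None) can have any length. A minimal NumPy sketch, assuming Dense(10) acts as a plain affine map (its default is activation=None); the weights here are random placeholders, not Keras' initializers.

```python
import numpy as np

# Placeholder weights with the shapes Dense(10) would create for 2 inputs.
rng = np.random.default_rng(0)
kernel = rng.standard_normal((2, 10))  # (input_dim, units)
bias = np.zeros(10)                    # (units,)

def dense_forward(x):
    """Affine map over the last axis; x has shape (batch, 2)."""
    return x @ kernel + bias

batch = rng.standard_normal((4, 2))    # batch size 4 stands in for None
out = dense_forward(batch)
print(out.shape)                       # (4, 10)
print(kernel.size + bias.size)         # 30 trainable parameters
```

Note that no timestep axis appears anywhere: each row is processed independently.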
However, the following code:
x = k.layers.Input(shape=(2,))
y = k.layers.SimpleRNN(10)(x)
m = k.models.Model(x, y)
raises the exception ValueError: Input 0 is incompatible with layer simple_rnn_1: expected ndim=3, found ndim=2.
It works only if I add another dimension:
x = k.layers.Input(shape=(2,1))
y = k.layers.SimpleRNN(10)(x)
m = k.models.Model(x, y)
...but now, of course, my input would not be (None, 2) anymore.
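For reference, SimpleRNN (like Keras' other recurrent layers) expects input of shape (batch, timesteps, features), and the ndim=3 in the error message counts the batch axis. Input(shape=(2, 1)) therefore reads each 2-value vector as a sequence of two timesteps with one feature each. The NumPy sketch below (array values are arbitrary) shows that reshape next to the alternative of one timestep with two features; which one is right depends on whether the two values are consecutive in time.

```python
import numpy as np

# A batch of 3 plain 2-value vectors, i.e. (None, 2)-shaped data.
vectors = np.arange(6, dtype="float32").reshape(3, 2)

# Recurrent layers want (batch, timesteps, features); two ways to add an axis:
as_sequence = np.expand_dims(vectors, axis=-1)    # 2 timesteps, 1 feature each
as_single_step = np.expand_dims(vectors, axis=1)  # 1 timestep, 2 features

print(as_sequence.shape)     # (3, 2, 1), matches Input(shape=(2, 1))
print(as_single_step.shape)  # (3, 1, 2), would match Input(shape=(1, 2))
```

Inside the model itself, one option is a k.layers.Reshape((2, 1)) layer placed right after Input(shape=(2,)), so the model still accepts (None, 2) arrays while the RNN receives the 3D tensor it needs.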
Output of model.summary():
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 2, 1)              0
_________________________________________________________________
simple_rnn_1 (SimpleRNN)     (None, 10)                120
=================================================================
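The 120 parameters also show that the sequence length (2) plays no part in the weights: SimpleRNN(10) holds an input kernel, a recurrent kernel and a bias (use_bias=True is the default). A small helper reproduces the count; the name simple_rnn_param_count is hypothetical, not a Keras function.

```python
def simple_rnn_param_count(input_dim, units, use_bias=True):
    """Count SimpleRNN weights: input kernel, recurrent kernel, optional bias."""
    kernel = input_dim * units        # maps x_t (input_dim values) to units
    recurrent_kernel = units * units  # maps h_{t-1} (units values) to units
    bias = units if use_bias else 0
    return kernel + recurrent_kernel + bias

print(simple_rnn_param_count(1, 10))  # 120: 1*10 + 10*10 + 10
print(simple_rnn_param_count(2, 10))  # 130 for 2 features per timestep
```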
How can I have an input of shape batch_size x 2 when I just want to feed vectors with 2 values to the network?
Furthermore, how would I chain RNN cells?
x = k.layers.Input(shape=(2, 1))
h = k.layers.SimpleRNN(10)(x)
y = k.layers.SimpleRNN(10)(h)
m = k.models.Model(x, y)
...raises the same exception about incompatible dimensions.
This version works:
x = k.layers.Input(shape=(2, 1))
h = k.layers.SimpleRNN(10, return_sequences=True)(x)
y = k.layers.SimpleRNN(10)(h)
m = k.models.Model(x, y)
...but then layer h no longer outputs (None, 10) but (None, 2, 10), since it returns the whole sequence instead of just the "regular" RNN cell output.
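The underlying reason is that a recurrent layer consumes one timestep at a time: the second SimpleRNN needs an input vector for each of the 2 timesteps, while return_sequences=False hands it only the final hidden state, a 2D tensor with no timestep axis left. A minimal NumPy sketch of the Elman recurrence h_t = tanh(x_t W + h_{t-1} U + b), assuming Keras' default tanh activation and a zero initial state; the function name simple_rnn_forward and the random weights are illustrative, not Keras internals.

```python
import numpy as np

def simple_rnn_forward(x, kernel, recurrent_kernel, bias, return_sequences=False):
    """Elman recurrence over x of shape (batch, timesteps, features).

    Returns the last hidden state (batch, units), or every hidden state
    (batch, timesteps, units) when return_sequences is True.
    """
    batch, timesteps, _ = x.shape
    units = recurrent_kernel.shape[0]
    h = np.zeros((batch, units))  # initial state: zeros
    outputs = []
    for t in range(timesteps):
        h = np.tanh(x[:, t, :] @ kernel + h @ recurrent_kernel + bias)
        outputs.append(h)
    return np.stack(outputs, axis=1) if return_sequences else h

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 2, 1))  # (batch, timesteps=2, features=1)

# First layer: 1 feature -> 10 units, keeping every timestep.
W1, U1, b1 = rng.standard_normal((1, 10)), rng.standard_normal((10, 10)), np.zeros(10)
h_seq = simple_rnn_forward(x, W1, U1, b1, return_sequences=True)
print(h_seq.shape)  # (4, 2, 10): still a sequence, so another RNN can consume it

# Second layer: 10 features -> 10 units, keeping only the last state.
W2, U2, b2 = rng.standard_normal((10, 10)), rng.standard_normal((10, 10)), np.zeros(10)
y = simple_rnn_forward(h_seq, W2, U2, b2)
print(y.shape)      # (4, 10)
```

The "regular" output is simply the last slice of the full sequence, so return_sequences=True loses nothing; it only keeps the intermediate steps the next layer iterates over.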
Why is this needed at all?
Moreover: where are the states? Do they just default to 1 recurrent state?
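In an Elman network the recurrent state is the hidden vector itself, so SimpleRNN(10) keeps one state of size 10 per sample (not a single scalar), which by default starts at zeros for every new sequence. A NumPy sketch under those assumptions; Keras exposes this tensor through the layer's return_state=True argument and can carry it across batches with stateful=True, but the function elman_rnn below is only an illustration.

```python
import numpy as np

def elman_rnn(x, kernel, recurrent_kernel, bias, initial_state=None):
    """Run an Elman RNN; return (all hidden states, final state).

    The only state is one vector of size `units` per sample. It starts at
    zero unless an initial_state of shape (batch, units) is passed in.
    """
    batch, timesteps, _ = x.shape
    units = recurrent_kernel.shape[0]
    h = np.zeros((batch, units)) if initial_state is None else initial_state
    states = []
    for t in range(timesteps):
        h = np.tanh(x[:, t, :] @ kernel + h @ recurrent_kernel + bias)
        states.append(h)
    return np.stack(states, axis=1), h

rng = np.random.default_rng(1)
W, U, b = rng.standard_normal((1, 10)), rng.standard_normal((10, 10)), np.zeros(10)
x = rng.standard_normal((4, 2, 1))

seq, final_state = elman_rnn(x, W, U, b)
print(final_state.shape)                        # (4, 10): one state vector per sample
print(np.allclose(seq[:, -1, :], final_state))  # True: the output is the state

# Splitting the sequence and passing the state along gives the same result,
# which is what carrying the state across batches amounts to.
_, h_mid = elman_rnn(x[:, :1, :], W, U, b)
_, h_end = elman_rnn(x[:, 1:, :], W, U, b, initial_state=h_mid)
print(np.allclose(h_end, final_state))          # True
```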