
I am new to Keras and LSTMs. I want to train a model on 2-dimensional sequences (i.e., movement in a grid space) rather than 1-dimensional sequences (like characters of text).

As a test, I first tried a single dimension, and that works with the following setup:

from keras.models import Sequential
from keras.layers import LSTM, Dense

# X: (samples, SEQUENCE_LENGTH, CATEGORY_LENGTH), y: (samples, CATEGORY_LENGTH)
model = Sequential()
model.add(LSTM(512, return_sequences=True, input_shape=X[0].shape, dropout=0.2, recurrent_dropout=0.2))
model.add(LSTM(512, return_sequences=False, dropout=0.2))
model.add(Dense(len(y[0]), activation="softmax"))
model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=['accuracy'])
model.fit(X, y, epochs=50)

I'm formatting the data like this:

import numpy as np
from keras.utils import to_categorical

data = []  # placeholder: fill with your 1D list of integers
inputs = []
outputs = []
for i in range(len(data) - SEQUENCE_LENGTH):
    inputs.append(data[i:i + SEQUENCE_LENGTH])   # window of SEQUENCE_LENGTH steps
    outputs.append(data[i + SEQUENCE_LENGTH])    # the step that follows the window
X = np.array([to_categorical(np.array(seq), CATEGORY_LENGTH) for seq in inputs])
y = to_categorical(np.array(outputs), CATEGORY_LENGTH)

This is straightforward and converges quickly.
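The part that matters for the 2D case is the layout of X: (samples, timesteps, features), where the one-hot vector is the feature axis. A minimal NumPy-only sketch of the same windowing makes those shapes visible. It swaps `to_categorical` for an `np.eye` lookup, which gives the same one-hot rows for integer labels in range; the function name `make_windows` is hypothetical.

```python
import numpy as np


def make_windows(data, seq_len, num_classes):
    """Slide a window of length seq_len over data and one-hot encode it.

    Returns X with shape (samples, seq_len, num_classes) and
    y with shape (samples, num_classes), matching the loop above.
    """
    data = np.asarray(data, dtype=int)
    eye = np.eye(num_classes, dtype="float32")  # row k is the one-hot vector for class k
    n = max(len(data) - seq_len, 0)
    # Row i of idx is [i, i+1, ..., i+seq_len-1], i.e. the i-th input window.
    idx = np.arange(seq_len)[None, :] + np.arange(n)[:, None]
    X = eye[data[idx]]
    y = eye[data[seq_len:]]  # the element after each window is its target
    return X, y


X, y = make_windows([0, 1, 2, 3, 2, 1, 0], seq_len=3, num_classes=4)
print(X.shape, y.shape)  # (4, 3, 4) (4, 4)
```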

But if instead of a list of integers, my data consists of 2D tuples, I can no longer create categorical (one-hot) arrays to pass to the LSTM layers.
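One way around this, assuming the tuples are non-negative integer grid coordinates and the grid size is known, is to flatten each (x, y) pair into a single cell index, `x * GRID_HEIGHT + y`. The flattened sequence is an ordinary list of integers, so the one-hot pipeline above could be reused with CATEGORY_LENGTH = GRID_WIDTH * GRID_HEIGHT, and a predicted index maps back to coordinates with `divmod`. `GRID_WIDTH`, `GRID_HEIGHT`, and the helper names are hypothetical.

```python
# Hypothetical grid size; substitute the real dimensions of your grid.
GRID_WIDTH = 10
GRID_HEIGHT = 8


def encode_cell(point):
    """Map an (x, y) grid coordinate to a single category index."""
    x, y = point
    if not (0 <= x < GRID_WIDTH and 0 <= y < GRID_HEIGHT):
        raise ValueError(f"point {point} is outside the grid")
    return x * GRID_HEIGHT + y


def decode_cell(index):
    """Invert encode_cell: recover (x, y) from a category index."""
    return divmod(index, GRID_HEIGHT)


path = [(0, 0), (1, 0), (1, 1), (2, 1), (9, 7)]
data = [encode_cell(p) for p in path]  # a plain 1D list of integers
CATEGORY_LENGTH = GRID_WIDTH * GRID_HEIGHT
print(data)  # [0, 8, 9, 17, 79]
print(decode_cell(data[-1]))  # (9, 7)
```

This keeps the softmax output and `categorical_crossentropy` loss from the 1D setup; the trade-off is one output unit per grid cell, so the output layer grows with the grid area.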

I've tried not using categorical arrays and simply passing the tuples to the model. In this case, I've changed my output layer to:

model.add(Dense(1, activation="linear"))

But that does not converge, or at least converges incredibly slowly.
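If each target is an (x, y) pair, a single-unit `Dense(1, ...)` output can only produce one number for both coordinates, so a regression setup would need two output units, e.g. `Dense(2, activation="linear")`, compiled with a regression loss such as `"mse"` (accuracy is not a meaningful metric for continuous targets). Below is a minimal sketch of the matching data layout, assuming integer grid coordinates and a hypothetical `GRID_SIZE` used to scale them into a common range, which often helps gradient-based training. Each time step then carries two features instead of a one-hot vector.

```python
import numpy as np

# Hypothetical grid size, used only to scale coordinates into [0, 1).
GRID_SIZE = (10, 8)


def make_coordinate_windows(points, seq_len, grid_size=GRID_SIZE):
    """Window a sequence of (x, y) points for regression.

    Returns X with shape (samples, seq_len, 2) and y with shape (samples, 2),
    where y is the point that follows each window.
    """
    # Divide each axis by its grid dimension so both features share a scale.
    coords = np.asarray(points, dtype="float32") / np.asarray(grid_size, dtype="float32")
    n = max(len(coords) - seq_len, 0)
    idx = np.arange(seq_len)[None, :] + np.arange(n)[:, None]
    return coords[idx], coords[seq_len:]


X, y = make_coordinate_windows([(0, 0), (1, 0), (1, 1), (2, 1), (3, 2)], seq_len=3)
print(X.shape, y.shape)  # (2, 3, 2) (2, 2)
```

With this layout, `input_shape=X[0].shape` becomes `(seq_len, 2)` and the rest of the LSTM stack can stay the same.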

How can I adapt this code to handle input with additional dimensions?

Bish
  • I am not sure whether you changed several factors at the same time, or whether changing one factor forced you to change another. Can you tell me: 1. Why are you using categorical data instead of a 2D array for the grid-location sequence? 2. Why is this line necessary for tuples? `model.add(Dense(1, activation="linear"))` –  May 31 '19 at 02:59

1 Answer


This previous answer should apply to your question as well. The only difference is that you will have to convert your tuples to a data frame first.
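The previous answer referred to above is not included here, so this is only one possible reading of "convert your tuple to a data frame": store each dimension in its own column, build lagged copies with `DataFrame.shift`, drop the incomplete rows, and reshape into the (samples, timesteps, features) layout the LSTM expects. `SEQ_LEN` and the column names are hypothetical.

```python
import numpy as np
import pandas as pd

SEQ_LEN = 3  # hypothetical window length

points = [(0, 0), (1, 0), (1, 1), (2, 1), (3, 2)]
df = pd.DataFrame(points, columns=["x", "y"])  # one column per dimension

# Build lagged copies, oldest first: the columns for lag k hold the position k steps back.
lagged = pd.concat(
    [df.shift(k).add_suffix(f"_t-{k}") for k in range(SEQ_LEN, 0, -1)]
    + [df.add_suffix("_target")],
    axis=1,
).dropna()  # the first SEQ_LEN rows have no complete history

# Reshape inputs to (samples, timesteps, features) for the LSTM.
X = lagged.iloc[:, : 2 * SEQ_LEN].to_numpy(dtype="float32").reshape(-1, SEQ_LEN, 2)
y = lagged.iloc[:, 2 * SEQ_LEN:].to_numpy(dtype="float32")
print(X.shape, y.shape)  # (2, 3, 2) (2, 2)
```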

igodfried