
Can someone tell me how to initialize the hidden state of an LSTM network with user-defined values in TensorFlow? I am trying to incorporate side information into the LSTM by giving the first LSTM cell a specific hidden state.

Q.Wang

1 Answer


You can pass the initial hidden state of the LSTM through the initial_state parameter of the function that unrolls the graph.

I am assuming that you will use one of the following TensorFlow functions to create the recurrent neural network (RNN): tf.nn.dynamic_rnn, tf.nn.bidirectional_dynamic_rnn, tf.nn.static_rnn, or tf.nn.static_bidirectional_rnn. All of them accept an initial state. The unidirectional functions take an initial_state parameter; for a bidirectional RNN, you pass separate initial states for the forward (initial_state_fw) and backward (initial_state_bw) passes.
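For the bidirectional case, here is a minimal sketch that swaps in the TF 2.x Keras API instead of tf.nn.bidirectional_dynamic_rnn. It assumes toy shapes and random stand-in states. `tf.keras.layers.Bidirectional` splits the `initial_state` list in half, giving the first half to the forward LSTM and the second half to the backward LSTM, and each half is in Keras' [h, c] order.

```python
import tensorflow as tf

batch_size, time_steps, num_features, num_units = 4, 10, 8, 16
inputs = tf.random.normal([batch_size, time_steps, num_features])

# Hypothetical user-defined hidden states, one per direction; cell states are zero
h_fw = tf.random.normal([batch_size, num_units])
h_bw = tf.random.normal([batch_size, num_units])
c_zero = tf.zeros([batch_size, num_units])

bi_lstm = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(num_units, return_state=True))

# First half of initial_state -> forward LSTM, second half -> backward LSTM,
# each given as [h, c]
output, fw_h, fw_c, bw_h, bw_c = bi_lstm(
    inputs, initial_state=[h_fw, c_zero, h_bw, c_zero])

print(output.shape)  # (4, 32): forward and backward outputs concatenated
```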

Here is an example that defines a model with tf.nn.dynamic_rnn:

```python
# TensorFlow 1.x (graph mode) API
import tensorflow as tf

batch_size = 32
max_sequence_length = 100
num_features = 128
num_units = 64

# Padded input batch and the true length of each sequence
input_sequence = tf.placeholder(tf.float32, shape=[batch_size, max_sequence_length, num_features])
input_sequence_lengths = tf.placeholder(tf.int32, shape=[batch_size])

# With state_is_tuple=True the LSTM state is an LSTMStateTuple (c, h)
cell = tf.nn.rnn_cell.LSTMCell(num_units=num_units, state_is_tuple=True)

# Initial states: zero cell state, user-supplied hidden state fed at run time
cell_state = tf.zeros([batch_size, num_units], tf.float32)
hidden_state = tf.placeholder(tf.float32, [batch_size, num_units])
my_initial_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)

# Unroll the LSTM over time, starting from my_initial_state
outputs, states = tf.nn.dynamic_rnn(
    cell=cell,
    inputs=input_sequence,
    initial_state=my_initial_state,
    sequence_length=input_sequence_lengths)
```
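The example above uses the TF 1.x graph API (placeholders and tf.nn.dynamic_rnn). As a rough TF 2.x counterpart, here is a sketch with tf.keras.layers.LSTM, assuming random tensors in place of the placeholders. Two differences are worth noting: Keras takes the initial state as a list in [h, c] order, the reverse of LSTMStateTuple(c, h), and per-sequence lengths are expressed as a boolean mask built with tf.sequence_mask rather than a sequence_length argument.

```python
import tensorflow as tf

batch_size = 32
max_sequence_length = 100
num_features = 128
num_units = 64

# Random stand-ins for the placeholders of the TF 1.x example
inputs = tf.random.normal([batch_size, max_sequence_length, num_features])
lengths = tf.random.uniform(
    [batch_size], minval=1, maxval=max_sequence_length + 1, dtype=tf.int32)

# User-defined hidden state (e.g. derived from side information); zero cell state
hidden_state = tf.random.normal([batch_size, num_units])
cell_state = tf.zeros([batch_size, num_units])

lstm = tf.keras.layers.LSTM(num_units, return_sequences=True, return_state=True)

# Boolean mask [batch, time]: True for valid steps, False for padding
mask = tf.sequence_mask(lengths, maxlen=max_sequence_length)

# Keras expects the initial state as [h, c]
outputs, final_h, final_c = lstm(
    inputs, initial_state=[hidden_state, cell_state], mask=mask)

print(outputs.shape, final_h.shape)  # (32, 100, 64) (32, 64)
```

With the mask, final_h and final_c should be the states at each sequence's last valid step, mirroring what sequence_length does in tf.nn.dynamic_rnn.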

Since we use state_is_tuple=True, the initial state must be an LSTMStateTuple containing the cell state and the hidden state, in that order. In the LSTMCell documentation these two components are called c_state and m_state; as a previous discussion points out, they represent the cell state and the hidden state, respectively.
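To see where each component of the state enters the computation, here is a NumPy sketch of a single step of a standard LSTM (no peepholes or projection). It is not TensorFlow's exact implementation: the fused weight layout, the i, f, g, o gate order, and the omission of LSTMCell's forget_bias are assumptions of this sketch. The previous cell state c only passes through the forget gate, while the previous hidden state h enters every gate through the recurrent weights U, so a user-defined h0 affects the very first step even when c0 is zero.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, c_prev, h_prev, W, U, b):
    """One step of a standard LSTM.

    x: [batch, in_dim]; c_prev, h_prev: [batch, units]
    W: [in_dim, 4*units]; U: [units, 4*units]; b: [4*units]
    """
    z = x @ W + h_prev @ U + b              # h_prev feeds all four gates
    i, f, g, o = np.split(z, 4, axis=1)     # gate order is an assumption
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)  # new cell state
    h = sigmoid(o) * np.tanh(c)                         # new hidden state
    return c, h


rng = np.random.default_rng(0)
batch, in_dim, units = 2, 3, 4
W = rng.normal(size=(in_dim, 4 * units))
U = rng.normal(size=(units, 4 * units))
b = np.zeros(4 * units)
x = rng.normal(size=(batch, in_dim))

c0 = np.zeros((batch, units))           # cell state initialized with zeros
h0_zero = np.zeros((batch, units))
h0_user = rng.normal(size=(batch, units))  # user-defined hidden state

_, h1_zero = lstm_step(x, c0, h0_zero, W, U, b)
_, h1_user = lstm_step(x, c0, h0_user, W, U, b)
print(np.allclose(h1_zero, h1_user))  # False: h0 changes the first step
```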

Therefore, since we only want to set the hidden state fed to the first time step, the cell state is initialized with zeros.
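For the side-information use case in the question, one common option (an assumption here, not something this answer prescribes) is to compute the initial hidden state from the side information with a learned layer, so that the projection is trained together with the LSTM. Below is a minimal TF 2.x Keras sketch; the class SideInfoLSTM, the tanh Dense projection, and the toy shapes are all hypothetical.

```python
import tensorflow as tf


class SideInfoLSTM(tf.keras.Model):
    """Hypothetical model: side information -> initial hidden state of an LSTM."""

    def __init__(self, num_units):
        super().__init__()
        # Learned projection from the side-information vector to h0
        self.to_h0 = tf.keras.layers.Dense(num_units, activation="tanh")
        self.lstm = tf.keras.layers.LSTM(num_units)

    def call(self, inputs):
        sequence, side_info = inputs
        h0 = self.to_h0(side_info)
        c0 = tf.zeros_like(h0)  # cell state starts at zero
        # Keras expects the initial state as [h, c]
        return self.lstm(sequence, initial_state=[h0, c0])


batch_size, time_steps, num_features, side_dim, num_units = 4, 10, 8, 5, 16
sequence = tf.random.normal([batch_size, time_steps, num_features])
side_info = tf.random.normal([batch_size, side_dim])

model = SideInfoLSTM(num_units)
with tf.GradientTape() as tape:
    output = model((sequence, side_info))
    loss = tf.reduce_sum(output)

# Gradients reach the projection, so it is trained jointly with the LSTM
grads = tape.gradient(loss, model.to_h0.trainable_variables)
print(output.shape)  # (4, 16)
```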

K. Bogdan