I'm building a model that requires the network to be copied before training, so there is an "old" and a "new" network. Training is only performed on the new network, and the old network is static. The magnitude of the training update is clipped depending on how different the two networks are, to prevent large updates (see PPO: https://arxiv.org/abs/1707.06347).
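
For context, the copy itself is done with assign ops between the two variable lists, roughly like this (a minimal sketch; params and old_params are the lists returned by _build_cnet for the two scopes, and the op name is my own):

# Copy each "new" variable into its "old" counterpart. The two lists
# line up because both scopes build the same graph in the same order.
self.sync_op = [old.assign(new) for old, new in zip(old_params, params)]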

In tf.layers, it is easy to set a trainable flag like so:

def _build_cnet(self, name, trainable):
    """Build the value network; trainable freezes/unfreezes the whole scope."""
    w_reg = tf.contrib.layers.l2_regularizer(L2_REG)

    with tf.variable_scope(name):
        l1 = tf.layers.dense(self.state, 400, tf.nn.relu, trainable=trainable,
                             kernel_regularizer=w_reg, name="vf_l1")
        l2 = tf.layers.dense(l1, 400, tf.nn.relu, trainable=trainable,
                             kernel_regularizer=w_reg, name="vf_l2")
        vf = tf.layers.dense(l2, 1, trainable=trainable,
                             kernel_regularizer=w_reg, name="vf_out")

    # Collect every variable in the scope (trainable or not) so the old
    # network can be synced from the new one.
    params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
    return vf, params
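
The reason the trainable flag matters here: variables created with trainable=False never enter tf.GraphKeys.TRAINABLE_VARIABLES, so the optimizer skips the old network automatically. A quick sanity check (a sketch; "new_net" and "old_net" are placeholder scope names):

# Only the network built with trainable=True contributes variables to
# the TRAINABLE_VARIABLES collection that optimizers minimize over.
new_train = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="new_net")
old_train = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="old_net")
assert len(old_train) == 0  # the old network was built with trainable=False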

I'm trying to add an LSTM layer at the end of my network, like this:

def _build_cnet(self, name, trainable):
    """Same network as above, with an LSTM layer before the value output."""
    w_reg = tf.contrib.layers.l2_regularizer(L2_REG)

    with tf.variable_scope(name):
        c_lstm = tf.contrib.rnn.BasicLSTMCell(CELL_SIZE)
        self.c_init_state = c_lstm.zero_state(batch_size=1, dtype=tf.float32)

        l1 = tf.layers.dense(self.state, 400, tf.nn.relu, trainable=trainable,
                             kernel_regularizer=w_reg, name="vf_l1")
        l2 = tf.layers.dense(l1, 400, tf.nn.relu, trainable=trainable,
                             kernel_regularizer=w_reg, name="vf_l2")

        # LSTM layer: treat the batch as a single sequence of time steps.
        c_outputs, self.c_final_state = tf.nn.dynamic_rnn(
            cell=c_lstm, inputs=tf.expand_dims(l2, axis=0),
            initial_state=self.c_init_state)
        c_cell_out = tf.reshape(c_outputs, [-1, CELL_SIZE],
                                name="flatten_lstm_outputs")

        vf = tf.layers.dense(c_cell_out, 1, trainable=trainable,
                             kernel_regularizer=w_reg, name="vf_out")

    params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
    return vf, params

Is there an easy way to add a trainable flag to either tf.contrib.rnn.BasicLSTMCell or tf.nn.dynamic_rnn?

It seems like RNNCell has a trainable flag, but BasicLSTMCell doesn't?
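
One workaround would presumably be to drop the trainable flag entirely and instead pass an explicit var_list to the optimizer, restricted to the new network's scope (a sketch; the loss tensor and scope name are placeholders):

# Train only the variables under the "new" scope, so the old network's
# LSTM weights are never updated even though they are marked trainable.
new_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="new_net")
self.train_op = tf.train.AdamOptimizer(1e-4).minimize(self.loss, var_list=new_vars)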
