
I ran into a problem when checking the "TPU Compatibility" of a bidirectional RNN. TensorBoard tells me that the sequence-reversal operation, which uses the sequence length vector, is incompatible with TPU, and I don't know why.

My simple code:

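import numpy as np
import tensorflow as tf  # written against the TensorFlow 1.x API (tf.placeholder, tf.Session)

X_batch = np.array([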
                    [[0., 1., 2.], [8., 2., 1.], [9., 8., 7.]],
                    [[3., 4., 5.], [9., 7., 4.], [0., 0., 0.]],
                    [[6., 7., 8.], [3., 6., 7.], [6., 5., 4.]],
                    [[9., 0., 1.], [0., 0., 0.], [0., 0., 0.]]
                   ])
seq_length_batch = np.array([3, 2, 3, 1])

batch = 4
n_steps = 3
input_size = 3
inputs = tf.placeholder(tf.float32, [batch, n_steps, input_size])
seq_len = tf.placeholder(tf.int32, [None])

def biLSTM(inputs, seq_len, n_hidden, batch_size):
    lstm_fw = tf.nn.rnn_cell.LSTMCell(n_hidden, state_is_tuple=True)
    lstm_bw = tf.nn.rnn_cell.LSTMCell(n_hidden, state_is_tuple=True)

    _initial_state_fw = lstm_fw.zero_state(batch_size, tf.float32)
    _initial_state_bw = lstm_bw.zero_state(batch_size, tf.float32)

    output, _states = tf.nn.bidirectional_dynamic_rnn(lstm_fw, lstm_bw, inputs,
                                   initial_state_fw=_initial_state_fw,
                                   initial_state_bw=_initial_state_bw,
                                   sequence_length=seq_len)

    final_outputs = tf.concat([output[0], output[1]], 2)
    return final_outputs

biLSTM_model = biLSTM(inputs, seq_len, 4, batch)

with tf.Session() as sess:
    check_write = tf.summary.FileWriter('../test_tensorboard', sess.graph)
    init = tf.global_variables_initializer()
    init.run()
    print(sess.run(biLSTM_model, feed_dict={inputs: X_batch,
                                           seq_len: seq_length_batch}))
    check_write.close()  # flush the graph event file so TensorBoard can read it

TensorBoard screenshots (images not preserved):

  1. Incompatible Operation
  2. ReverseSequence1
  3. ReverseSequence2
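For context: in the TensorFlow 1.x source, `tf.nn.bidirectional_dynamic_rnn` reverses the inputs with `tf.reverse_sequence` (when `sequence_length` is given) before running the backward cell, and then reverses the backward outputs back. That would explain the two ReverseSequence ops in the graph. The NumPy sketch below only reproduces what that op computes for `seq_axis=1, batch_axis=0`; the helper names are hypothetical. The second function expresses the same reversal as a gather with indices built from static shapes, one possible rewrite to try if the op is rejected; whether that form actually compiles for a given TPU/TensorFlow version is an untested assumption.

```python
import numpy as np

def reverse_sequence_reference(x, seq_lengths):
    """NumPy sketch of tf.reverse_sequence(x, seq_lengths, seq_axis=1, batch_axis=0).

    For each batch row i, the first seq_lengths[i] time steps are reversed;
    the padded steps after them stay where they are.
    """
    out = x.copy()
    for i, n in enumerate(seq_lengths):
        out[i, :n] = x[i, :n][::-1]
    return out

def reverse_sequence_gather(x, seq_lengths):
    """Same result, written as a gather whose index array has a static shape."""
    t = np.arange(x.shape[1])[None, :]            # (1, n_steps) time indices
    n = np.asarray(seq_lengths)[:, None]          # (batch, 1) valid lengths
    idx = np.where(t < n, n - 1 - t, t)           # reversed inside the valid part, identity in padding
    return np.take_along_axis(x, idx[:, :, None], axis=1)

# Usage with the data from the question.
X_batch = np.array([
    [[0., 1., 2.], [8., 2., 1.], [9., 8., 7.]],
    [[3., 4., 5.], [9., 7., 4.], [0., 0., 0.]],
    [[6., 7., 8.], [3., 6., 7.], [6., 5., 4.]],
    [[9., 0., 1.], [0., 0., 0.], [0., 0., 0.]],
])
seq_length_batch = np.array([3, 2, 3, 1])

rev = reverse_sequence_reference(X_batch, seq_length_batch)
print(rev[1].tolist())  # [[9.0, 7.0, 4.0], [3.0, 4.0, 5.0], [0.0, 0.0, 0.0]]
```

Note that row 1 keeps its zero padding at the end; only the two valid steps swap places, which is why this op needs the sequence length vector at all.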

aman2930
Which version of TensorFlow? It was implemented fairly recently (it should be in 1.7+ at least), but otherwise it [looks supported](https://cloud.google.com/tpu/docs/tensorflow-ops). – Allen Lavoie May 07 '18 at 23:26

1 Answer


I have the same issue. I went as far as running it on a TPU and hit a roadblock: the RNN cannot be unrolled because it contains a conditional while loop (it loops until the end of the input is reached).

One possible fix is to pad your input data to a constant length and replace the conditional while loop with one that runs for a static number of steps.
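A minimal NumPy sketch of that suggestion, assuming a plain tanh RNN cell instead of an LSTM to keep it short. `pad_batch` and `masked_static_rnn` are hypothetical helper names, and this only illustrates the padding plus static-loop idea; it is not a TensorFlow or TPU implementation. A per-step mask takes over the job of `sequence_length`: past each row's length the state is carried over unchanged and the output is zero, which mirrors how `dynamic_rnn` treats padded steps.

```python
import numpy as np

def pad_batch(sequences, max_len, input_size):
    """Zero-pad variable-length sequences into a fixed (batch, max_len, input_size) array."""
    batch = np.zeros((len(sequences), max_len, input_size), dtype=np.float32)
    lengths = np.zeros(len(sequences), dtype=np.int32)
    for i, seq in enumerate(sequences):
        seq = np.asarray(seq, dtype=np.float32)[:max_len]  # truncate anything longer than max_len
        batch[i, :len(seq)] = seq
        lengths[i] = len(seq)
    return batch, lengths

def masked_static_rnn(x, lengths, W, U, b):
    """Run a simple tanh RNN for exactly x.shape[1] steps, masking padded steps."""
    batch, n_steps, _ = x.shape
    n_hidden = W.shape[1]
    h = np.zeros((batch, n_hidden), dtype=x.dtype)
    outputs = np.zeros((batch, n_steps, n_hidden), dtype=x.dtype)
    for t in range(n_steps):                      # fixed trip count, no data-dependent stop condition
        h_new = np.tanh(x[:, t] @ W + h @ U + b)
        valid = (t < lengths)[:, None]            # (batch, 1): is step t inside each sequence?
        h = np.where(valid, h_new, h)             # freeze the state once a sequence has ended
        outputs[:, t] = np.where(valid, h_new, 0.0)
    return outputs, h

# Usage with the question's data, written as variable-length sequences.
sequences = [
    [[0., 1., 2.], [8., 2., 1.], [9., 8., 7.]],
    [[3., 4., 5.], [9., 7., 4.]],
    [[6., 7., 8.], [3., 6., 7.], [6., 5., 4.]],
    [[9., 0., 1.]],
]
x, lengths = pad_batch(sequences, max_len=3, input_size=3)
print(lengths.tolist())  # [3, 2, 3, 1]

rng = np.random.default_rng(0)
n_hidden = 4
W = (0.1 * rng.normal(size=(3, n_hidden))).astype(np.float32)
U = (0.1 * rng.normal(size=(n_hidden, n_hidden))).astype(np.float32)
b = np.zeros(n_hidden, dtype=np.float32)
outputs, h_final = masked_static_rnn(x, lengths, W, U, b)
print(outputs.shape)  # (4, 3, 4)
```

Because the loop always runs `n_steps` times, its length is known when the graph is built and it can be unrolled; the cost is that short sequences still pay for every padded step.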