
I want to calculate the gradient of the loss function with respect to all the hidden states in a 2-dimensional LSTM, which is implemented using TensorArray and tf.while_loop (refer to this repo for implementation details). The basic idea is that one TensorArray (of length time_steps) stores all the hidden states, and tf.while_loop computes the hidden states step by step. In the body of the while loop, we calculate h_t from h_t-1, like:

h_prev = h_tensor_array.read(t - 1)  # read the previous hidden state from the TensorArray
x_t = x_tensor_array.read(t)  # read the current input from the TensorArray
o_t, h_t = cell(x_t, h_prev)  # RNN cell step
h_tensor_array = h_tensor_array.write(t, h_t)  # write() returns a new TensorArray, so reassign it
o_tensor_array = o_tensor_array.write(t, o_t)  # likewise for the output TensorArray
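For reference, the loop body above can be sketched as a small self-contained example. This is an assumption-laden stand-in, not the repo's actual code: tf.keras.layers.SimpleRNNCell replaces the custom 2-D LSTM cell, and a plain 1-D time axis replaces the 2-D grid. Note that TensorArray.write returns a new TensorArray that must be carried through the loop, and clear_after_read=False is needed if stack() will be called afterwards:

```python
import tensorflow as tf

time_steps, batch, hidden = 3, 1, 4
cell = tf.keras.layers.SimpleRNNCell(hidden)  # stand-in for the custom 2-D LSTM cell
x = tf.random.normal([time_steps, batch, hidden])

# Inputs: one TensorArray element per step.
x_ta = tf.TensorArray(tf.float32, size=time_steps).unstack(x)
# Hidden states: slot 0 holds h_0, slot t+1 will hold h_{t+1}.
# clear_after_read=False keeps elements alive for a later stack().
h_ta = tf.TensorArray(tf.float32, size=time_steps + 1, clear_after_read=False)
h_ta = h_ta.write(0, tf.zeros([batch, hidden]))
o_ta = tf.TensorArray(tf.float32, size=time_steps)

def body(t, h_ta, o_ta):
    h_prev = h_ta.read(t)             # previous hidden state
    x_t = x_ta.read(t)                # current input
    o_t, [h_t] = cell(x_t, [h_prev])  # one cell step
    h_ta = h_ta.write(t + 1, h_t)     # write() returns a new TensorArray
    o_ta = o_ta.write(t, o_t)
    return t + 1, h_ta, o_ta

_, h_ta, o_ta = tf.while_loop(
    lambda t, *_: t < time_steps, body, [tf.constant(0), h_ta, o_ta])

all_h = h_ta.stack()  # shape [time_steps + 1, batch, hidden]
```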

If I simply use

all_hidden_states = h_tensor_array.stack() # get all hidden states Tensor
tf.gradients(loss, all_hidden_states) # returns [None]

(the loss is calculated from the final hidden state), the gradients are None. I suspect that each call to a TensorArray's stack or read method returns a different Tensor, which means that h_t-1 and all_hidden_states are different nodes in the computational graph. Because cell consumes h_t-1, there is a path from loss to h_t-1, but no path from loss to all_hidden_states. Does anyone have an idea how to get the gradients for all hidden states? Thanks!
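One way to test that suspicion is to build the loss from the stacked tensor itself, so that all_hidden_states actually lies on the path to the loss. Below is a minimal TF2 / GradientTape sketch under the same assumptions as before (SimpleRNNCell as a hypothetical stand-in for the custom cell); if the loss instead reads the final state straight from the TensorArray, stack() produces a node the loss never touches, and the gradient comes back None:

```python
import tensorflow as tf

time_steps, batch, hidden = 3, 1, 4
cell = tf.keras.layers.SimpleRNNCell(hidden)  # hypothetical stand-in cell
x = tf.random.normal([time_steps, batch, hidden])

with tf.GradientTape() as tape:
    h = tf.zeros([batch, hidden])
    h_ta = tf.TensorArray(tf.float32, size=time_steps)
    for t in range(time_steps):
        _, [h] = cell(x[t], [h])
        h_ta = h_ta.write(t, h)        # reassign: write() returns a new TensorArray
    all_hidden_states = h_ta.stack()   # [time_steps, batch, hidden]
    # Build the loss FROM the stacked tensor, not from a separate read():
    loss = tf.reduce_sum(all_hidden_states[-1])

grads = tape.gradient(loss, all_hidden_states)  # a Tensor, no longer None
```

Because the loss here depends only on the last slice of all_hidden_states, the returned gradient is nonzero only at that slice; gradients flowing through the recurrence show up on the per-step h tensors instead.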

jzb
  • same None error in gradients here, when computing gradients of the hidden states with respect to the input. Have you solved it yet? https://stackoverflow.com/questions/71107377/tensorflow-obtain-rnn-hidden-states-gradients-with-respect-to-input – siegfried Feb 14 '22 at 05:44

0 Answers