I want to calculate the gradient of the loss function with respect to all the hidden states in a 2-dimensional LSTM, which is implemented using TensorArray and tf.while_loop (refer to this repo for implementation details). The basic idea is that we have one TensorArray (with length time_steps) to store all the hidden states, and we use tf.while_loop to compute the hidden states step by step. In the body of the while loop, we calculate h_t based on h_t-1, like:
h_prev = h_tensor_array.read(t - 1)            # read previous hidden state h_t-1 from the TensorArray
x_t = x_tensor_array.read(t)                   # read current input from the TensorArray
o_t, h_t = cell(x_t, h_prev)                   # RNN cell
h_tensor_array = h_tensor_array.write(t, h_t)  # write() returns a new TensorArray, so keep the result
o_tensor_array = o_tensor_array.write(t, o_t)  # update output TensorArray
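
For reference, here is a stripped-down, self-contained sketch of the setup (TF 1.x style; the sizes and the simple tanh step are just placeholders standing in for the actual 2-D LSTM cell in the repo):

import tensorflow as tf

# All sizes here are made up just for the sketch.
time_steps, batch_size, input_dim, hidden_dim = 5, 2, 3, 4

x = tf.random_normal([time_steps, batch_size, input_dim])
x_tensor_array = tf.TensorArray(tf.float32, size=time_steps).unstack(x)

# A minimal tanh step stands in for the actual 2-D LSTM cell;
# its variables are created outside the loop.
W = tf.get_variable("W", [input_dim + hidden_dim, hidden_dim])
b = tf.get_variable("b", [hidden_dim], initializer=tf.zeros_initializer())

def cell(x_t, h_prev):
    h_t = tf.tanh(tf.matmul(tf.concat([x_t, h_prev], axis=1), W) + b)
    return h_t, h_t                              # (output, new hidden state)

# clear_after_read=False so the states can still be stacked after being read in the loop.
h_tensor_array = tf.TensorArray(tf.float32, size=time_steps, clear_after_read=False)
o_tensor_array = tf.TensorArray(tf.float32, size=time_steps)

# Initial state/output at index 0.
h_tensor_array = h_tensor_array.write(0, tf.zeros([batch_size, hidden_dim]))
o_tensor_array = o_tensor_array.write(0, tf.zeros([batch_size, hidden_dim]))

def cond(t, h_ta, o_ta):
    return t < time_steps

def body(t, h_ta, o_ta):
    h_prev = h_ta.read(t - 1)                    # previous hidden state
    x_t = x_tensor_array.read(t)                 # current input
    o_t, h_t = cell(x_t, h_prev)                 # RNN cell step
    h_ta = h_ta.write(t, h_t)                    # thread the TensorArrays through the loop
    o_ta = o_ta.write(t, o_t)
    return t + 1, h_ta, o_ta

_, h_tensor_array, o_tensor_array = tf.while_loop(
    cond, body, [tf.constant(1), h_tensor_array, o_tensor_array])

# The loss only depends on the final hidden state.
h_final = h_tensor_array.read(time_steps - 1)
loss = tf.reduce_sum(h_final)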
If I simply use
all_hidden_states = h_tensor_array.stack()     # stack all hidden states into a single Tensor
tf.gradients(loss, all_hidden_states)          # returns [None]
(the loss is calculated based on the final hidden state), the gradients are None. I suspect that each time we call the stack or read methods of a TensorArray, the returned Tensors are not the same instance, which means that the h_prev tensor read inside the loop (i.e. h_t-1) and all_hidden_states are different nodes in the computational graph. Because h_prev is what we feed into cell, there is a path between loss and h_prev, but no path exists between loss and all_hidden_states.
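
To make the suspicion concrete, this is roughly the check I run on top of the toy sketch above (h_final, loss and h_tensor_array are the names from that sketch):

# Continuing from the sketch above:
all_hidden_states = h_tensor_array.stack()      # shape [time_steps, batch_size, hidden_dim]

# loss was built from the read() tensor h_final, so this gradient exists...
print(tf.gradients(loss, h_final))              # [<tf.Tensor 'gradients/...'>]

# ...but the stack() output is a separate node with no forward path into loss,
# so this comes back as [None], which matches what I see in the real model.
print(tf.gradients(loss, all_hidden_states))    # [None]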
Does anyone have an idea of how to get the gradients for all the hidden states? Thanks!