I am wondering whether PyTorch tensors whose Python variables are overwritten are still kept in PyTorch's computational graph.
So here is a small example: an RNN model whose hidden states (and some other variables) are reset after every iteration, while backward() is called later.
Example:
loss = 0
for i in range(5):
    output = rnn_model(inputs[i])
    loss += criterion(output, target[i])
    # hidden states are overwritten with a zero vector
    rnn_model.reset_hidden_states()
loss.backward()
So my question is:
Is there a problem with overwriting the hidden states before calling backward()? Or does the computational graph keep the necessary information about the hidden states of previous iterations in memory to compute the gradients?
Edit: It would be great to have a statement from an official source on this, e.g. stating that all variables relevant to the computational graph are kept, no matter whether any other Python references to these variables still exist. I assume that the graph itself holds a reference that prevents the garbage collector from deleting them, but I'd like to know if this is truly the case.
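Here is a minimal self-contained sketch of what I am assuming (the tensor names are just for this example). If the graph keeps its own references to intermediate tensors, I would expect this to run and print the correct gradient even though the Python variable h is overwritten:

import torch

x = torch.ones(3, requires_grad=True)
h = x * 2            # intermediate tensor, part of the graph
y = (h * 3).sum()    # y's grad_fn chain references h internally

h = None             # drop the only Python reference to the intermediate

y.backward()         # should still work if the graph holds its own reference
print(x.grad)        # expected: tensor([6., 6., 6.])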
Thanks in advance!