I am wondering whether PyTorch tensors whose Python variables are overwritten are still kept in PyTorch's computational graph.


Here is a small example: an RNN model whose hidden states (and some other variables) are reset after every iteration, with backward() being called later.

Example:

loss = 0
for i in range(5):
    output = rnn_model(inputs[i])
    loss += criterion(output, target[i])
    # hidden states are overwritten with a zero vector
    rnn_model.reset_hidden_states()
loss.backward()

So my question is:

  • Is there a problem with overwriting the hidden states before calling backward()?

  • Or does the computational graph keep the necessary information about the hidden states of previous iterations in memory to compute the gradients?

  • Edit: It would be great to have a statement from an official source on this, e.g. stating that all variables relevant for the computational graph are kept, no matter whether there are still other Python references to these variables. I assume that there is a reference in the graph itself preventing the garbage collector from deleting it, but I'd like to know if this is truly the case (a minimal toy sketch of what I mean follows below).
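
For illustration, here is a minimal toy example (my own sketch, not from an official source) in which the only Python reference to an intermediate tensor is dropped before backward(), and the backward pass still succeeds:

import torch

w = torch.randn(3, requires_grad=True)
hidden = torch.sigmoid(w)   # intermediate tensor; its value is needed for the backward pass
loss = hidden.sum()
del hidden                  # drop the only Python reference to the intermediate tensor
loss.backward()             # still works; the graph apparently kept what it needs
print(w.grad)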

Thanks in advance!

MBT

1 Answer


I think it is OK to reset before calling backward(); the graph preserves the required information. Here is a small experiment:

import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.f1 = torch.nn.Linear(10, 1)

    def forward(self, x):
        self.x = x                           # keep a reference to the input
        return torch.sigmoid(self.f1(self.x))

    def reset_x(self):
        self.x = torch.zeros(self.x.shape)   # overwrite the stored input

net = A()
net.zero_grad()
X = torch.rand(10, 10)

# First run: backward() without resetting x
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10, 1))
loss.backward()
params = list(net.parameters())
for i in params:
    print(i.grad)
net.zero_grad()

# Second run: reset x and delete X before backward()
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10, 1))
net.reset_x()
print(net.x is X)   # False: self.x now points to a fresh zero tensor
del X               # drop the remaining Python reference to the input
loss.backward()     # still succeeds
params = list(net.parameters())
for i in params:
    print(i.grad)

In the above code, I print the gradients with and without resetting the input x. The gradient certainly depends on x, yet resetting it does not change anything, so I think the graph preserves the information needed for the backward op.
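
To make the comparison explicit, here is a small follow-up check (a sketch reusing `class A` from above, not part of the original experiment; it assumes no optimizer step in between): run the same input twice, once without and once with the reset, and verify that the parameter gradients match.

net = A()
X = torch.rand(10, 10)
target = torch.ones(10, 1)

# First pass: no resetting
loss = torch.nn.functional.binary_cross_entropy(net(X), target)
loss.backward()
grads_plain = [p.grad.clone() for p in net.parameters()]

# Second pass: same input, but reset the stored x and drop the
# outer reference before calling backward()
net.zero_grad()
loss = torch.nn.functional.binary_cross_entropy(net(X), target)
net.reset_x()
del X
loss.backward()
grads_reset = [p.grad.clone() for p in net.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_plain, grads_reset)))  # True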

Umang Gupta
  • Hi, thanks for your answer! Do you have any reference for this? I've used it a couple of times in this way, and empirically I've come to the same conclusion, but I wonder if this is always the case. – MBT Oct 13 '18 at 15:44
  • Thank you as well for the example, but you have to change it a bit in order to make it conclusive: with `self.x` you are only deleting one (of two) references, `X` is still a valid reference. You will see that `X is net.x` returns `True`. So to correct this you have to either delete `X` as well or change `self.x = x` to `self.x = x.clone()` (see the sketch after these comments). With that correction it works, but as stated above I'm wondering if this conclusion is always valid. – MBT Oct 13 '18 at 15:45
  • After resetting, `net.x is X` is `False`; and OK, I added `del X` and it is still `False`. – Umang Gupta Oct 13 '18 at 16:34
  • No, what I meant is that `self.x = X` will result in `self.x is X`. So overwriting `self.x` will not delete the tensor in `X`, as `X` is still a valid reference to this tensor. Of course after resetting they are not the same... – MBT Oct 13 '18 at 16:46
  • Yes, but following this you can imagine all sorts of overwrites and it still computes properly. Even after overwriting only the data property I am getting correct answers. – Umang Gupta Oct 13 '18 at 17:15
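
A toy sketch of the aliasing point discussed in the comments (the hypothetical `class B` below is not part of the original answer): overwriting `self.x` only rebinds that attribute, while the tensor that `X` refers to stays untouched, which is why the second run above also deletes `X`.

import torch

class B(torch.nn.Module):
    # Hypothetical mini-module mirroring class A, just to show the aliasing.
    def forward(self, x):
        self.x = x                            # self.x and the caller's X are the same object
        return x.sum()

    def reset_x(self):
        self.x = torch.zeros_like(self.x)     # rebinds self.x only

net = B()
X = torch.rand(3)
net(X)
print(net.x is X)   # True: one tensor, two names
net.reset_x()
print(net.x is X)   # False: self.x now points to a new zero tensor,
print(X)            # but X still refers to the original, untouched tensor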