
I have a situation where, for each mini-batch, there are multiple nested pieces of data on which the model needs to be trained.

for idx, batch in enumerate(train_dataloader):
    data = batch.get("data").squeeze(0)
    op = torch.zeros(size)  # zero initialization
    for i in range(data.shape[0]):
        optimizer.zero_grad()
        current_data = data[i, ...]
        start_to_current_data = data[:i + 1, ...]
        target = some_transformation_func(start_to_current_data)
        op = model(current_data, op)
        loss = criterion(op, target)
        loss.backward()
        optimizer.step()

But when I start training, I get the following error: `RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.` Setting `retain_graph=True` increases memory usage so much that I cannot train the model. How can I fix this?
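For context, here is a minimal, self-contained sketch of the same failure mode (a hypothetical `nn.Linear` model and toy shapes standing in for the real ones): chaining each output into the next forward pass without detaching makes the second `backward()` walk back into the first iteration's graph, whose saved buffers have already been freed.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

op = torch.zeros(1, 4)
for step in range(3):
    optimizer.zero_grad()
    x = torch.randn(1, 4)
    op = model(x + op)                       # `op` still references the previous graph
    loss = criterion(op, torch.zeros(1, 4))
    loss.backward()                          # raises the RuntimeError on the second iteration
    optimizer.step()

Using `retain_graph=True` silences the error but keeps every previous iteration's graph alive (since `op` keeps referencing them), which is why memory keeps growing.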

Rajat Sharma
  • Hello, this is due to the instruction `op = model(current_data, op)`. Since you never detach `op` from the computational graph, it will backprop through previous computations. To help find a fix, I would need you to explain what you are trying to achieve here. Which parameter are you trying to optimize, and why are you re-using `op` at every iteration? – trialNerror Jun 23 '21 at 12:04
  • Thanks for the feedback. I want to feed the past output, along with the current data, into the model to produce the final output. For the first n epochs, the past output is calculated from some function and fed to the model together with the input. As the validation error between the model output and the target decreases, I want to fine-tune the model so that the model's own output is used as the past output instead of calculating it from that function. I hope you understand. – Rajat Sharma Jun 23 '21 at 21:27
  • Okay, I mostly needed to know whether you needed to backprop anything through `op`, but that does not seem to be the case, so I think you can try replacing the instruction with `op = model(current_data, op.detach())` (see the sketch below). I believe this should work. – trialNerror Jun 24 '21 at 12:07
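Following the suggestion in the comments, here is a minimal sketch of the fix applied to the posted loop (`size`, `some_transformation_func`, `model`, `criterion`, and `optimizer` as defined in the question). Detaching `op` truncates the history, so each `backward()` only traverses the graph built in the current inner iteration:

for idx, batch in enumerate(train_dataloader):
    data = batch.get("data").squeeze(0)
    op = torch.zeros(size)  # zero initialization of the past output
    for i in range(data.shape[0]):
        optimizer.zero_grad()
        current_data = data[i, ...]
        start_to_current_data = data[:i + 1, ...]
        target = some_transformation_func(start_to_current_data)
        # detach() cuts the link to earlier iterations, so no gradient
        # flows back into (already freed) previous graphs
        op = model(current_data, op.detach())
        loss = criterion(op, target)
        loss.backward()
        optimizer.step()

If gradients through several past steps were actually needed, the alternative would be to accumulate the loss over a window and call `backward()` once per window (truncated backpropagation through time), but the comments indicate that is not required here.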

0 Answers