2

I am relatively new to PyTorch. I want to use this model to generate some images, but it was written before PyTorch 1.5. Since the gradient calculation was fixed in that release, I now get this error message:

RuntimeError: one of the variables needed for gradient computation has been 
modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] 
is at version 2; expected version 1 instead. 
Hint: enable anomaly detection to find the operation that 
failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I have looked at past examples and am not sure what the problem is. I believe it is happening somewhere in the code below, but I don't know where. Any help would be greatly appreciated!

def process(self, images, edges, masks):
    self.iteration += 1

    # zero optimizers
    self.gen_optimizer.zero_grad()
    self.dis_optimizer.zero_grad()


    # process outputs
    outputs = self(images, edges, masks)
    gen_loss = 0
    dis_loss = 0


    # discriminator loss
    dis_input_real = torch.cat((images, edges), dim=1)
    dis_input_fake = torch.cat((images, outputs.detach()), dim=1)
    dis_real, dis_real_feat = self.discriminator(dis_input_real)        # in: (grayscale(1) + edge(1))
    dis_fake, dis_fake_feat = self.discriminator(dis_input_fake)        # in: (grayscale(1) + edge(1))
    dis_real_loss = self.adversarial_loss(dis_real, True, True)
    dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
    dis_loss += (dis_real_loss + dis_fake_loss) / 2


    # generator adversarial loss
    gen_input_fake = torch.cat((images, outputs), dim=1)
    gen_fake, gen_fake_feat = self.discriminator(gen_input_fake)        # in: (grayscale(1) + edge(1))
    gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
    gen_loss += gen_gan_loss


    # generator feature matching loss
    gen_fm_loss = 0
    for i in range(len(dis_real_feat)):
        gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach())
    gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT
    gen_loss += gen_fm_loss


    # create logs
    logs = [
        ("l_d1", dis_loss.item()),
        ("l_g1", gen_gan_loss.item()),
        ("l_fm", gen_fm_loss.item()),
    ]

    return outputs, gen_loss, dis_loss, logs

def forward(self, images, edges, masks):
    edges_masked = (edges * (1 - masks))
    images_masked = (images * (1 - masks)) + masks
    inputs = torch.cat((images_masked, edges_masked, masks), dim=1)
    outputs = self.generator(inputs)                                    # in: [grayscale(1) + edge(1) + mask(1)]
    return outputs

def backward(self, gen_loss=None, dis_loss=None):
    if dis_loss is not None:
        dis_loss.backward()
    self.dis_optimizer.step()

    if gen_loss is not None:
        gen_loss.backward()
    self.gen_optimizer.step()

Thank you!

Linux Penguin

3 Answers

4

You can't compute the losses for the discriminator and the generator in one go and then run both backpropagations back-to-back like this:

if dis_loss is not None:
    dis_loss.backward()
self.dis_optimizer.step()

if gen_loss is not None:
    gen_loss.backward()
self.gen_optimizer.step()

Here's why: calling self.dis_optimizer.step() modifies the discriminator's parameters in place, and those are the very same parameters that were used to compute gen_loss, which you then try to backpropagate through. Autograd saved them for that backward pass and detects that they have changed since, hence the error.
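To make the failure concrete, here is a minimal sketch that reproduces the same error on CPU. The tiny gen and dis modules are hypothetical stand-ins, not the model from the question; only the order of operations mirrors the question's backward method.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real generator and discriminator.
gen = nn.Linear(4, 4)
dis = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
dis_opt = torch.optim.SGD(dis.parameters(), lr=0.1)

fake = gen(torch.randn(2, 4))

# Both losses are built up front, as in the question's process().
dis_loss = dis(fake.detach()).mean()
gen_loss = dis(fake).mean()  # this graph saves the discriminator weights

dis_loss.backward()
dis_opt.step()  # in-place update of the weights saved by gen_loss's graph

try:
    gen_loss.backward()  # autograd notices the saved weights have changed
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

The printed message should be the same "modified by an inplace operation" error as in the question, just with a different tensor shape.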

You have to compute dis_loss, backpropagate, update the weights of the discriminator, and clear the gradients. Only then can you compute gen_loss with the newly updated discriminator weights. Finally, backpropagate through the generator and update its weights.
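As a rough sketch of that ordering, assuming toy linear modules and a plain BCE loss in place of the question's actual networks and adversarial_loss (all names here are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real generator and discriminator.
gen = nn.Linear(4, 4)
dis = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.1)
dis_opt = torch.optim.SGD(dis.parameters(), lr=0.1)
bce = nn.BCEWithLogitsLoss()

real = torch.randn(2, 4)
fake = gen(torch.randn(2, 4))
ones, zeros = torch.ones(2, 1), torch.zeros(2, 1)
gen_weight_before = gen.weight.detach().clone()

# 1) Discriminator step: the fake batch is detached, so no gradient
#    reaches the generator here.
dis_opt.zero_grad()
dis_loss = (bce(dis(real), ones) + bce(dis(fake.detach()), zeros)) / 2
dis_loss.backward()
dis_opt.step()

# 2) Generator step: run a fresh forward pass through the already
#    updated discriminator, then backpropagate into the generator.
gen_opt.zero_grad()
gen_loss = bce(dis(fake), ones)
gen_loss.backward()  # also fills dis grads; the next dis_opt.zero_grad() clears them
gen_opt.step()
```

The key point is that gen_loss is only built after dis_opt.step(), so its graph never holds discriminator weights that are later modified in place.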

This tutorial is a good walkthrough of typical GAN training.

Ivan
0

This might not be an exact answer to your question, but I got this error when trying to use a "custom" distributed optimizer: I was using Cherry's optimizer and accidentally wrapping the model in DDP at the same time. Once I only moved the model to the device, in line with how Cherry works, the issue went away.

context: https://github.com/learnables/learn2learn/issues/263

Dharman
Charlie Parker
0

Passing retain_graph=True to dis_loss.backward() worked for me. For more details, please see here.

def backward(self, gen_loss=None, dis_loss=None):
    if dis_loss is not None:
        dis_loss.backward(retain_graph=True)  # modified here
    self.dis_optimizer.step()

    if gen_loss is not None:
        gen_loss.backward()
    self.gen_optimizer.step()
Peter Pack