
I would like to speed up my WGAN code written in PyTorch. In pseudocode, it looks like this:

import torch

# gen, crit, gen_optim, crit_optim, batches, noise and num_epochs are defined elsewhere
n_times_critic = 5

for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(batches):
        z_fake = gen(noise)
        z_real = batch
        real_score = crit(z_real)
        fake_score = crit(z_fake.detach())
        c_loss = torch.mean(fake_score) - torch.mean(real_score)
        # clear stale critic gradients (including those left by the generator step)
        crit_optim.zero_grad()
        # backprop the critic loss
        c_loss.backward()
        # update critic optimizer
        crit_optim.step()

        if batch_idx % n_times_critic == 0:
            gen_optim.zero_grad()
            fake_score = crit(gen(noise))
            g_loss = -torch.mean(fake_score)
            # backprop the generator loss
            g_loss.backward()
            # update generator optimizer
            gen_optim.step()
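
The `batch_idx % n_times_critic == 0` condition steps the critic on every batch and the generator on batches 0, 5, 10, and so on. As a rough illustration, this pure-Python sketch lists which network is updated on each batch of one epoch; `update_schedule` is a hypothetical helper that only mirrors the condition and trains nothing.

```python
def update_schedule(num_batches, n_times_critic=5):
    """Return (batch_idx, networks updated) pairs for one epoch.

    Hypothetical helper mirroring the `batch_idx % n_times_critic == 0`
    condition of the training loop; it does not train anything.
    """
    schedule = []
    for batch_idx in range(num_batches):
        updated = ["critic"]  # the critic is stepped on every batch
        if batch_idx % n_times_critic == 0:
            updated.append("generator")
        schedule.append((batch_idx, updated))
    return schedule


gen_batches = [i for i, nets in update_schedule(12) if "generator" in nets]
print(gen_batches)  # [0, 5, 10]
```

Note that the first generator update comes after a single critic update. Because `batch_idx` restarts at 0 each epoch, the gap is exactly five critic updates only when the number of batches per epoch is a multiple of `n_times_critic`. With 12 batches, only two critic updates (batches 11 and 0) separate the generator step at batch 10 from the one at batch 0 of the next epoch.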

To use automatic mixed precision (AMP), I changed it to the following:

import torch

# gen, crit, gen_optim, crit_optim, batches, noise and num_epochs are defined elsewhere
crit_scaler = torch.cuda.amp.GradScaler()
gen_scaler = torch.cuda.amp.GradScaler()

n_times_critic = 5
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(batches):
        with torch.cuda.amp.autocast():
            z_fake = gen(noise)
            z_real = batch
            real_score = crit(z_real)
            fake_score = crit(z_fake.detach())
        c_loss = torch.mean(fake_score) - torch.mean(real_score)

        crit_optim.zero_grad()
        # Scales loss. Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        crit_scaler.scale(c_loss).backward()
        # unscale_() divides the gradients by the current scale in place; it is only
        # required when gradients are inspected or clipped before step().
        # scaler.step() then calls optimizer.step() if the gradients contain no infs
        # or NaNs; otherwise, optimizer.step() is skipped.
        crit_scaler.unscale_(crit_optim)
        crit_scaler.step(crit_optim)
        crit_scaler.update()

        if batch_idx % n_times_critic == 0:
            gen_optim.zero_grad()
            with torch.cuda.amp.autocast():
                fake_score = crit(gen(noise))
                g_loss = -torch.mean(fake_score)
            gen_scaler.scale(g_loss).backward()
            gen_scaler.unscale_(gen_optim)
            gen_scaler.step(gen_optim)
            gen_scaler.update()
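
For comparison, here is a minimal sketch of one critic update plus an optional generator update. It clears gradients with `zero_grad()` right before each backward pass and omits the optional `unscale_()` call. `wgan_amp_step` and its parameters are hypothetical. It uses `torch.autocast`, which with `device_type="cuda"` is equivalent to `torch.cuda.amp.autocast`. The generator's backward pass also writes gradients into the critic's parameters, so zeroing the critic's gradients immediately before its own update keeps them out of the critic step. This is a sketch under those assumptions, not a confirmed fix for the NaN.

```python
import torch
import torch.nn as nn


def wgan_amp_step(gen, crit, gen_optim, crit_optim, crit_scaler, gen_scaler,
                  real, noise, update_gen, device_type="cuda", use_autocast=True):
    """One critic update and an optional generator update (hypothetical helper)."""
    # Critic: drop stale grads, including those the previous generator
    # backward pass wrote into the critic's parameters.
    crit_optim.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device_type, enabled=use_autocast):
        fake = gen(noise)
        c_loss = crit(fake.detach()).mean() - crit(real).mean()
    # backward() runs outside the autocast region, on the scaled loss
    crit_scaler.scale(c_loss).backward()
    crit_scaler.step(crit_optim)  # skips optimizer.step() on inf/NaN grads
    crit_scaler.update()

    g_loss = None
    if update_gen:
        gen_optim.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device_type, enabled=use_autocast):
            g_loss = -crit(gen(noise)).mean()
        gen_scaler.scale(g_loss).backward()
        gen_scaler.step(gen_optim)
        gen_scaler.update()
    return c_loss.detach(), (None if g_loss is None else g_loss.detach())
```

On a GPU, the scalers would be created with `torch.cuda.amp.GradScaler()`. Passing `enabled=False` to a scaler's constructor, or `use_autocast=False` to the helper, switches scaling or autocasting off independently, which is the kind of isolation the AMP recipe suggests for tracking down infs/NaNs.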

While the code without automatic mixed precision worked perfectly fine, using AMP causes the loss to become NaN. As recommended in the docs (https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#advanced-topics), I switched off the GradScaler to diagnose the problem. Now there are no longer any NaN values in the loss, but I don't know how to interpret this. Is there a bug in AMP, or did I make a mistake in my implementation?
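
To make it easier to reason about what switching off the GradScaler changes, here is a pure-Python toy version of dynamic loss scaling. `ToyLossScaler` is hypothetical and simplified; it is not PyTorch's implementation. Its defaults mirror the documented defaults of `torch.cuda.amp.GradScaler` (`init_scale=2.**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`), and float16 overflow is approximated by comparing against 65504, the largest finite float16 value. Scaling exists so that small gradients do not underflow to zero in float16; the price is that large ones can overflow.

```python
import math

FP16_MAX = 65504.0  # largest finite float16 value


class ToyLossScaler:
    """Simplified dynamic loss scaling (hypothetical, not torch's code)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, raw_grads):
        """Return unscaled grads if the step is taken, else None (skipped)."""
        scale = self.scale
        # Scaling the loss by `scale` scales every gradient by `scale`.
        scaled = [g * scale for g in raw_grads]
        # Approximate float16 overflow: inf/NaN, or above the float16 maximum.
        found_inf = any(not math.isfinite(s) or abs(s) > FP16_MAX for s in scaled)
        if found_inf:
            self.scale = scale * self.backoff_factor  # back off
            self._good_steps = 0
            return None  # the optimizer step is skipped
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale = scale * self.growth_factor  # try a larger scale
            self._good_steps = 0
        return [s / scale for s in scaled]  # unscale with the scale used


scaler = ToyLossScaler()
print(scaler.step([1.0, 0.5]))  # None: 1.0 * 65536 overflows, step skipped
print(scaler.scale)             # 32768.0
print(scaler.step([1.0, 0.5]))  # [1.0, 0.5]
```

In the toy, an overflowing or NaN gradient causes the step to be skipped and the scale to be halved, mirroring how `GradScaler.step()` skips `optimizer.step()` when it finds infs or NaNs.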

Ynjxsjmh
