I would like to accelerate my WGAN code, which is written in PyTorch. In pseudocode, it looks like this:
n_times_critic = 5
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(batches):
        # critic update
        crit_optim.zero_grad()
        z_fake = gen(noise)
        z_real = batch
        real_score = crit(z_real)
        fake_score = crit(z_fake.detach())
        c_loss = torch.mean(fake_score) - torch.mean(real_score)
        # backprop the critic loss
        c_loss.backward()
        # update the critic parameters
        crit_optim.step()
        # update the generator every n_times_critic batches
        if batch_idx % n_times_critic == 0:
            gen_optim.zero_grad()
            fake_score = crit(gen(noise))
            g_loss = -torch.mean(fake_score)
            # backprop the generator loss
            g_loss.backward()
            # update the generator parameters
            gen_optim.step()
To use automatic mixed precision (AMP), I changed it to the following:
n_times_critic = 5
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(batches):
        # critic update: only the forward passes and the loss run under autocast
        crit_optim.zero_grad()
        with torch.cuda.amp.autocast():
            z_fake = gen(noise)
            z_real = batch
            real_score = crit(z_real)
            fake_score = crit(z_fake.detach())
            c_loss = torch.mean(fake_score) - torch.mean(real_score)
        # Scales the loss and calls backward() on the scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended;
        # backward ops run in the same dtype autocast chose for the corresponding forward ops.
        crit_scaler.scale(c_loss).backward()
        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called;
        # otherwise, optimizer.step() is skipped.
        crit_scaler.unscale_(crit_optim)
        crit_scaler.step(crit_optim)
        crit_scaler.update()
        # generator update every n_times_critic batches
        if batch_idx % n_times_critic == 0:
            gen_optim.zero_grad()
            with torch.cuda.amp.autocast():
                fake_score = crit(gen(noise))
                g_loss = -torch.mean(fake_score)
            gen_scaler.scale(g_loss).backward()
            gen_scaler.unscale_(gen_optim)
            gen_scaler.step(gen_optim)
            gen_scaler.update()
While the code without automatic mixed precision trains perfectly fine, with AMP the loss becomes NaN. As recommended in the AMP recipe (https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#advanced-topics), I disabled the GradScaler to diagnose the problem. With the scaler switched off, the losses no longer contain any NaN values, but I don't know what that tells me. Is there a bug in AMP, or did I make a mistake in my implementation?
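For reference, the diagnostic run differs from the loop above only in how the scalers are constructed. Here is a minimal sketch of what I did, using the enabled flag from the recipe (the use_scaler name is just an illustrative toggle I introduced, everything else is the same as above):

import torch

# Toggle for debugging, as suggested in the AMP recipe: autocast stays on,
# but the scalers are constructed with enabled=False.
use_scaler = False
crit_scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)
gen_scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

# The training loop itself is unchanged. With enabled=False,
# crit_scaler.scale(c_loss) returns c_loss unmodified,
# crit_scaler.unscale_(...) and crit_scaler.update() are no-ops, and
# crit_scaler.step(crit_optim) simply calls crit_optim.step().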