I am trying to train a GAN with a pix2pix generator and a U-Net as the discriminator. However, after some epochs my discriminator loss stops changing and stays stuck at around 5.546. Is this a good or a bad sign for GAN training?
This is my loss calculation:
def discLoss(rValid, rLabel, fValid, fLabel, label=None):
    """Return the total discriminator loss for one batch.

    Sums, for both real and generated images:
      * a validity term: binary cross-entropy on logits with
        label smoothing 0.1 (real -> 1, fake -> 0), and
      * a classification term: sparse categorical cross-entropy on logits
        against the batch's class ids (the same ids for real and fake).

    Args:
        rValid: validity logits for the real images.
        rLabel: class logits for the real images.
        fValid: validity logits for the generated images.
        fLabel: class logits for the generated images.
        label: class ids for the batch. Optional only for backward
            compatibility. When omitted, a module-level ``label`` is looked
            up, which is what the original version did implicitly. Pass it
            explicitly so the loss uses the current batch's labels.

    Returns:
        Scalar tensor ``real_loss + fake_loss``. The terms are summed, not
        averaged.

    Raises:
        NameError: if ``label`` is omitted and no module-level ``label``
            exists.
    """
    if label is None:
        # Preserve the original behaviour of reading a module-global `label`.
        try:
            label = globals()["label"]
        except KeyError:
            raise NameError(
                "discLoss needs `label`: pass the batch's class ids explicitly"
            ) from None

    # Validity loss (real vs. fake), computed on logits.
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
    # Classifier loss over integer class ids, computed on logits.
    scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # NOTE: a classifier head predicting uniformly contributes ln(num_classes)
    # per classification term, so a flat total loss may mean that head has
    # stopped learning.
    real_dloss = bce(tf.ones_like(rValid), rValid) + scce(label, rLabel)
    fake_dloss = bce(tf.zeros_like(fValid), fValid) + scce(label, fLabel)
    return real_dloss + fake_dloss
def generator_loss(disc_generated_output, gen_output, target, lambda_l1=100):
    """Return the pix2pix-style generator loss: adversarial BCE + weighted L1.

    Args:
        disc_generated_output: discriminator validity logits for the
            generated images.
        gen_output: images produced by the generator.
        target: paired ground-truth images, compared element-wise with
            ``gen_output``.
        lambda_l1: weight of the L1 reconstruction term. Defaults to 100,
            the value previously hard-coded.

    Returns:
        Scalar tensor ``gan_loss + lambda_l1 * l1_loss``.
    """
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # The generator is rewarded when the discriminator labels its output real (1).
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # Mean absolute error between the target and the generated image.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    return gan_loss + (lambda_l1 * l1_loss)
This is my training step:
def train_step(img1, img2, label, generator, discriminator, generator_optimizer, discriminator_optimizer):
    """Run one simultaneous update of the generator and the discriminator.

    The generator receives ``[img1, label]``. The discriminator scores both
    ``img2`` (real) and the generated image. Both sets of gradients are
    computed from the same forward pass before either optimizer steps.

    Returns:
        ``(generator_loss, discriminator_loss)`` as NumPy values. The
        ``.numpy()`` calls require eager execution.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated = generator([img1, label], training=True)
        real_validity, real_class_logits = discriminator(img2, training=True)
        fake_validity, fake_class_logits = discriminator(generated, training=True)
        # NOTE(review): `label` is not forwarded to discLoss here, so its
        # classification terms only see this batch's labels if discLoss
        # obtains them some other way. Verify.
        d_total = discLoss(real_validity, real_class_logits, fake_validity, fake_class_logits)
        g_total = generator_loss(fake_validity, generated, img2)

    updates = (
        (gen_tape, g_total, generator, generator_optimizer),
        (disc_tape, d_total, discriminator, discriminator_optimizer),
    )
    # Compute every gradient first, then apply them (generator, then discriminator).
    grads = [tape.gradient(loss, net.trainable_variables) for tape, loss, net, _ in updates]
    for grad, (_, _, net, opt) in zip(grads, updates):
        opt.apply_gradients(zip(grad, net.trainable_variables))

    return tf.math.reduce_mean(g_total).numpy(), d_total.numpy()