
I am trying to train a GAN with a pix2pix-style generator and a U-Net as the discriminator. After some epochs my discriminator loss stops changing and gets stuck at a value around 5.546. Is this a good or a bad sign for GAN training?

This is my loss calculation:

def discLoss(rValid, rLabel, fValid, fLabel, label):
    # validity (real/fake) loss
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
    # auxiliary classifier loss
    scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # Loss for real images
    real_dloss = bce(tf.ones_like(rValid), rValid) + scce(label, rLabel)
    # Loss for fake images
    fake_dloss = bce(tf.zeros_like(fValid), fValid) + scce(label, fLabel)
    # Total discriminator loss
    d_loss = real_dloss + fake_dloss
    return d_loss

def generator_loss(disc_generated_output, gen_output, target):
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    LAMBDA = 100
    # mean absolute error, weighted as in pix2pix
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss = gan_loss + (LAMBDA * l1_loss)
    return total_gen_loss

This is my train step:

def train_step(img1, img2, label, generator, discriminator,
               generator_optimizer, discriminator_optimizer):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fImg = generator([img1, label], training=True)
        rValid, rLabel = discriminator(img2, training=True)
        fValid, fLabel = discriminator(fImg, training=True)

        disc_loss = discLoss(rValid, rLabel, fValid, fLabel, label)
        gen_loss = generator_loss(fValid, fImg, img2)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss.numpy(), disc_loss.numpy()
Girish Patel

1 Answer


This loss is too high. You need to watch that both G and D learn at an even pace. See this question and the related links there: How to balance the generator and the discriminator performances in a GAN?
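One common balancing heuristic (a sketch of one option, not something prescribed by the answer) is to update the discriminator only while its real/fake accuracy stays below a threshold, so it cannot race ahead of the generator. The function name and threshold below are hypothetical:

```python
import tensorflow as tf

def should_update_discriminator(rValid, fValid, threshold=0.8):
    """Return True while the discriminator is still weak enough to train.

    rValid / fValid are the validity logits the discriminator produced
    for the real and fake batches; logits > 0 mean "predicted real".
    """
    real_acc = tf.reduce_mean(tf.cast(rValid > 0.0, tf.float32))
    fake_acc = tf.reduce_mean(tf.cast(fValid < 0.0, tf.float32))
    disc_acc = (real_acc + fake_acc) / 2.0
    return bool(disc_acc < threshold)
```

Inside the training step, the discriminator optimizer's `apply_gradients` call would then be guarded by this check, while the generator is updated every step.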

Poe Dator