
Are these claims correct?

  • Model.trainable = False by itself has no effect on an already-compiled model; the flag is only consulted when compile() is called.
  • If I take two layers of ModelA, which has been compiled (ModelA.compile(...)), build a sub-model ModelB = Model(intermediate_layer1, intermediate_layer2), set ModelB.trainable = False, and call ModelB.compile(...), nothing changes for ModelA; assuming ModelA's trainable flags have not been touched, every weight in ModelA will still be updated when only ModelA is trained (ModelA.fit(...)).
  • This only affects weight updates, so weights will be saved/loaded without problems (even if they are the wrong weights).
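To make the first claim concrete, here is a minimal sketch of what I believe happens (using tf.keras; the standalone Keras from the issue thread behaves the same way, just with the extra warning): flipping trainable is reflected immediately in trainable_weights, but it only takes effect for training once compile() is called again.

```python
import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

inp = Input(shape=(2,))
model = Model(inputs=inp, outputs=Dense(2)(inp))
model.compile(loss='mse', optimizer='sgd')

# The flag is reflected immediately in the trainable_weights list...
model.trainable = False
print(len(model.trainable_weights))      # 0 (kernel and bias are now non-trainable)
print(len(model.non_trainable_weights))  # 2

# ...but the training function built by compile() still holds the weight
# list that was collected at compile time, which is exactly the
# "discrepancy" the warning complains about. Recompile so the freeze
# actually applies to training:
model.compile(loss='mse', optimizer='sgd')
```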

It all started when I tried to train my GAN, freezing the discriminator while training the generator, and got this warning:

 UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?

I looked into this and found that other people had run into it too:

https://github.com/keras-team/keras/issues/8585

Here's a reproducible example adapted from that Issue thread:

import numpy as np
from keras.layers import Activation, Dense, Input
from keras.models import Model

# making the discriminator
d_input = Input(shape=(2,))
d_output = Activation('softmax')(Dense(2)(d_input))
discriminator = Model(inputs=d_input, outputs=d_output)
discriminator.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

# making generator
g_input = Input(shape=(2,))
g_output = Activation('relu')(Dense(2)(g_input))
generator = Model(inputs=g_input, outputs=g_output)

# making gan(generator -> discriminator)
discriminator.trainable = False # CHECK THIS OUT!
gan = Model(inputs=g_input, outputs=discriminator(g_output))
gan.compile(loss='categorical_crossentropy', optimizer='adam')

# training
BATCH_SIZE = 3
some_input_data = np.array([[1,2],[3,4],[5,6]])
some_target_data = np.array([[1,1],[2,2],[3,3]])
# update discriminator
generated = generator.predict(some_input_data, verbose=0)
X = np.concatenate((some_target_data, generated), axis=0)
y = np.array([[0,1]]*BATCH_SIZE + [[1,0]]*BATCH_SIZE)
d_metrics = discriminator.train_on_batch(X, y)
# update generator
g_metrics = gan.train_on_batch(some_input_data, np.array([[0,1]]*BATCH_SIZE))
# loop these operations for batches...

I got confused because some people say it is a false warning, while others say the weights can actually be messed up.

Then I read this question: shouldn't model.trainable=False freeze weights under the model?

That post gave a good explanation of what "trainable" actually does. I would like to know whether my understanding is correct, and to make sure my GAN is training correctly.
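For reference, here is the sanity check I would run (a sketch, assuming tf.keras; the layer layout mirrors my GAN above but the variable names are mine) to confirm that a gan.train_on_batch call leaves the discriminator's weights untouched while the generator's weights do change:

```python
import numpy as np
from tensorflow.keras.layers import Activation, Dense, Input
from tensorflow.keras.models import Model

d_input = Input(shape=(2,))
discriminator = Model(inputs=d_input, outputs=Activation('softmax')(Dense(2)(d_input)))
discriminator.compile(loss='categorical_crossentropy', optimizer='adam')

g_input = Input(shape=(2,))
generator = Model(inputs=g_input, outputs=Dense(2)(g_input))

# Freeze BEFORE compiling the combined model: the flag is captured by
# gan.compile(), so the discriminator is frozen only inside gan.
discriminator.trainable = False
gan = Model(inputs=g_input, outputs=discriminator(generator(g_input)))
gan.compile(loss='categorical_crossentropy', optimizer='adam')

d_before = [w.copy() for w in discriminator.get_weights()]
g_before = [w.copy() for w in generator.get_weights()]
gan.train_on_batch(np.random.rand(3, 2), np.array([[0, 1]] * 3))

# discriminator weights should be unchanged; generator weights should move
assert all(np.allclose(a, b) for a, b in zip(d_before, discriminator.get_weights()))
assert not all(np.allclose(a, b) for a, b in zip(g_before, generator.get_weights()))
```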
