Currently, this is what I have for a cGAN model using model subclassing (ignore that the discriminator model is missing).
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Create the generator.
generator = keras.Sequential(
    [
        keras.layers.InputLayer((generator_in_channels,)),
        # We want to generate 7 * 7 * (128 + num_classes) coefficients to
        # reshape into a 7x7x(128 + num_classes) map.
        layers.Dense(7 * 7 * generator_in_channels),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, generator_in_channels)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
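# Note: generator_in_channels must equal latent_dim + num_classes, since
# train_step below feeds the generator the concatenation of a latent vector
# and a one-hot label. The 7x7 map is upsampled 7 -> 14 -> 28 by the two
# stride-2 transposed convolutions, so the output is a 28x28x1 image.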
class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
    def train_step(self, data):
        # Unpack the data.
        real_images, one_hot_labels = data

        # Add dummy dimensions to the labels so that they can be concatenated with
        # the images. This is for the discriminator.
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = tf.repeat(
            image_one_hot_labels, repeats=[image_size * image_size]
        )
        image_one_hot_labels = tf.reshape(
            image_one_hot_labels, (-1, image_size, image_size, num_classes)
        )
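        # Shape flow of the label broadcast above (assuming square images):
        # one_hot_labels:           (batch, num_classes)
        # after [:, :, None, None]: (batch, num_classes, 1, 1)
        # after repeat + reshape:   (batch, image_size, image_size, num_classes),
        # i.e. one constant "label map" per class channel.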
        # Sample random points in the latent space and concatenate the labels.
        # This is for the generator.
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)

        # Combine them with real images. Note that we are concatenating the labels
        # with these images here.
        fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], -1)
        real_image_and_labels = tf.concat([real_images, image_one_hot_labels], -1)
        combined_images = tf.concat(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )
        # Assemble labels discriminating real from fake images
        # (1 = fake, 0 = real, matching the concatenation order above).
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Assemble labels that say "all real images".
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = tf.concat([fake_images, image_one_hot_labels], -1)
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }
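For completeness, I compile and train the model along these lines (the learning rates are illustrative, and `dataset` is assumed to yield `(real_images, one_hot_labels)` batches):

cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    # Assumes the (omitted) discriminator outputs raw logits.
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
cond_gan.fit(dataset, epochs=20)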
However, for the generator inputs (a latent vector and a label vector), I would like to pass each one through its own dense layer, reshape the results, and then concatenate the reshaped tensors, so that this concatenated tensor becomes the input to the generator's convolutional stack. This is how I would do it with the Functional API (see below). My question is: can I simply swap in this generator, or does the ConditionalGAN(keras.Model) class require the generator to be a Sequential model (as shown above)?
# Inputs shared by the functions below.
con_label = layers.Input(shape=(1,), name="class_label")
latent_vector = layers.Input(shape=(100,), name="latent_vector")

def label_conditioned_generator(n_classes=3, embedding_dim=100):
    # Embedding for the categorical input.
    label_embedding = layers.Embedding(n_classes, embedding_dim)(con_label)
    # Linear transformation.
    nodes = 4 * 4
    label_dense = layers.Dense(nodes)(label_embedding)
    # Reshape to an additional channel.
    label_reshape_layer = layers.Reshape((4, 4, 1))(label_dense)
    return label_reshape_layer

def latent_input(latent_dim=100):
    # Image generator input.
    nodes = 512 * 4 * 4
    latent_dense = layers.Dense(nodes)(latent_vector)
    latent_dense = layers.ReLU()(latent_dense)
    latent_reshape = layers.Reshape((4, 4, 512))(latent_dense)
    return latent_reshape
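# Both branches produce 4x4 feature maps (4x4x1 and 4x4x512), so they can be
# concatenated channel-wise into a 4x4x513 tensor below.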
# Define the final generator model.
def define_generator():
    label_output = label_conditioned_generator()
    latent_vector_output = latent_input()
    # Merge the label_conditioned_generator and latent_input outputs.
    merge = layers.Concatenate()([latent_vector_output, label_output])
    init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    x = layers.Conv2DTranspose(
        64 * 8, kernel_size=4, strides=2, padding='same',
        kernel_initializer=init, use_bias=False, name='conv_transpose_1')(merge)
    x = layers.BatchNormalization(momentum=0.1, epsilon=0.8, name='bn_1')(x)
    x = layers.ReLU(name='relu_1')(x)
    x = layers.Conv2DTranspose(
        64 * 4, kernel_size=4, strides=2, padding='same',
        kernel_initializer=init, use_bias=False, name='conv_transpose_2')(x)
    x = layers.BatchNormalization(momentum=0.1, epsilon=0.8, name='bn_2')(x)
    x = layers.ReLU(name='relu_2')(x)
    x = layers.Conv2DTranspose(
        64 * 2, kernel_size=4, strides=2, padding='same',
        kernel_initializer=init, use_bias=False, name='conv_transpose_3')(x)
    x = layers.BatchNormalization(momentum=0.1, epsilon=0.8, name='bn_3')(x)
    x = layers.ReLU(name='relu_3')(x)
    x = layers.Conv2DTranspose(
        64 * 1, kernel_size=4, strides=2, padding='same',
        kernel_initializer=init, use_bias=False, name='conv_transpose_4')(x)
    x = layers.BatchNormalization(momentum=0.1, epsilon=0.8, name='bn_4')(x)
    x = layers.ReLU(name='relu_4')(x)
    out_layer = layers.Conv2DTranspose(
        3, kernel_size=4, strides=2, padding='same',
        kernel_initializer=init, use_bias=False, activation='tanh',
        name='conv_transpose_5')(x)
    # Define the model with both inputs.
    model = tf.keras.Model([con_label, latent_vector], out_layer)
    return model
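To make the difference concrete: train_step above calls self.generator(random_vector_labels) with a single concatenated tensor, whereas define_generator() returns a model expecting a list of two inputs. A rough sketch of what the generator calls in train_step would become (converting the one-hot labels back to integer indices with argmax is my assumption, since the Embedding layer expects class indices):

# Hypothetical replacement for the generator calls in train_step:
class_indices = tf.argmax(one_hot_labels, axis=1)[:, None]  # (batch, 1) integer labels
generated_images = self.generator([class_indices, random_latent_vectors])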