Models implemented as subclasses of `keras.Model` generally cannot be visualized with `plot_model`. There is a workaround as described here; however, it only applies to simple models. As soon as a model is enclosed by another model, the nesting will not be resolved.
I am looking for a way to resolve nested models implemented as subclasses of `keras.Model`. As an example, I have created a minimal GAN model:
import keras
from keras import layers
from tensorflow.python.keras.utils.vis_utils import plot_model
class BaseModel(keras.Model):
    """Base class that adds graph-plotting support for subclassed Keras models."""

    def __init__(self, *args, **kwargs):
        super(BaseModel, self).__init__(*args, **kwargs)

    def call(self, inputs, training=None, mask=None):
        # Bug fix: the parent call's result was computed but discarded,
        # so this method (and anything delegating to it) returned None.
        return super(BaseModel, self).call(inputs=inputs, training=training, mask=mask)

    def get_config(self):
        # Bug fix: return the config instead of silently dropping it,
        # otherwise serialization via get_config() yields None.
        return super(BaseModel, self).get_config()

    def build_graph(self, raw_shape):
        """Wrap this subclassed model in a functional `keras.Model` so it can
        be rendered with `plot_model`.

        Adapted from https://stackoverflow.com/questions/61427583/how-do-i-plot-a-keras-tensorflow-subclassing-api-model

        :param raw_shape: Shape tuple not containing the batch_size
        :return: A functional `keras.Model` tracing this model's `call`.
        """
        x = keras.Input(shape=raw_shape)
        return keras.Model(inputs=[x], outputs=self.call(x))
class GANModel(BaseModel):
    """Chains a generator and a discriminator into one GAN forward pass."""

    def __init__(self, generator, discriminator):
        super(GANModel, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def call(self, input_tensor, training=False, mask=None):
        # Generated sample is fed straight into the critic.
        generated = self.generator(input_tensor)
        return self.discriminator(generated)
class DiscriminatorModel(BaseModel):
    """Small convolutional critic: conv -> flatten -> single-unit dense."""

    def __init__(self, name="Critic"):
        super(DiscriminatorModel, self).__init__(name=name)
        self.l1 = layers.Conv2D(64, 2, activation=layers.ReLU())
        self.flat = layers.Flatten()
        self.dense = layers.Dense(1)

    def call(self, inputs, training=False, mask=None):
        # Apply the three layers in order, forwarding the training flag
        # where the original did.
        hidden = self.l1(inputs, training=training)
        flattened = self.flat(hidden)
        return self.dense(flattened, training=training)
class GeneratorModel(BaseModel):
    """Generator: dense projection -> reshape to 7x7x128 -> tanh conv output."""

    def __init__(self, name="Generator"):
        super(GeneratorModel, self).__init__(name=name)
        self.dense = layers.Dense(128, activation=layers.ReLU())
        self.reshape = layers.Reshape((7, 7, 128))
        self.out = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")

    def call(self, inputs, training=False, mask=None):
        # Project, reshape into a feature map, then produce the image.
        projected = self.dense(inputs, training=training)
        feature_map = self.reshape(projected)
        return self.out(feature_map, training=training)
# Build the sub-models and render each graph with nested layers expanded.
g = GeneratorModel()
d = DiscriminatorModel()

plot_model(
    g.build_graph((7, 7, 1)),
    to_file="generator_model.png",
    expand_nested=True,
    show_shapes=True,
)

# The GAN wraps both sub-models; this is the nesting that plot_model
# fails to resolve for subclassed models.
gan = GANModel(generator=g, discriminator=d)
plot_model(
    gan.build_graph((7, 7, 1)),
    to_file="gan_model.png",
    expand_nested=True,
    show_shapes=True,
)
Edit
Using the functional Keras API, I get the desired result (see here): the nested models are correctly resolved within the GAN model.
from keras import Model, layers, optimizers
from tensorflow.python.keras.utils.vis_utils import plot_model
def get_generator(input_dim):
    """Build the functional-API generator: dense -> reshape -> tanh conv."""
    inp = layers.Input(shape=input_dim)
    h = layers.Dense(128, activation=layers.ReLU())(inp)
    h = layers.Reshape((7, 7, 128))(h)
    image = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")(h)
    return Model(inputs=inp, outputs=image, name="Generator")
def get_discriminator(input_dim):
    """Build the functional-API critic: conv -> flatten -> single logit."""
    inp = layers.Input(shape=input_dim)
    h = layers.Conv2D(64, 2, activation=layers.ReLU())(inp)
    h = layers.Flatten()(h)
    logit = layers.Dense(1)(h)
    return Model(inputs=inp, outputs=logit, name="Discriminator")
def get_gan(input_dim, latent_dim):
    """Compose generator and discriminator into one functional GAN model.

    ``latent_dim`` is the discriminator's input shape (here equal to the
    generator's output shape).
    """
    inp = layers.Input(shape=input_dim)
    fake = get_generator(input_dim)(inp)
    verdict = get_discriminator(latent_dim)(fake)
    return Model(inputs=inp, outputs=verdict, name="GAN")
# Plot the functional variants; nested sub-models expand correctly here.
m = get_generator((7, 7, 1))
m.compile(optimizer=optimizers.Adam())
plot_model(
    m,
    expand_nested=True,
    show_shapes=True,
    to_file="generator_model_functional.png",
)

gan = get_gan((7, 7, 1), (7, 7, 1))
plot_model(
    gan,
    expand_nested=True,
    show_shapes=True,
    to_file="gan_model_functional.png",
)