
Models implemented as subclasses of `keras.Model` generally cannot be visualized with `plot_model`. There is a workaround, as described here, but it only applies to simple models: as soon as a model is enclosed by another model, the nesting is not resolved.

I am looking for a way to resolve nested models implemented as subclasses of `keras.Model`. As an example, I have created a minimal GAN model:

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model


class BaseModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super(BaseModel, self).__init__(*args, **kwargs)

    def call(self, inputs, training=None, mask=None):
        return super(BaseModel, self).call(inputs=inputs, training=training, mask=mask)

    def get_config(self):
        return super(BaseModel, self).get_config()

    def build_graph(self, raw_shape):
        """ Plot models that subclass `keras.Model`

        Adapted from https://stackoverflow.com/questions/61427583/how-do-i-plot-a-keras-tensorflow-subclassing-api-model

        :param raw_shape: Shape tuple not containing the batch_size
        :return:
        """
        x = keras.Input(shape=raw_shape)
        return keras.Model(inputs=[x], outputs=self.call(x))


class GANModel(BaseModel):
    def __init__(self, generator, discriminator):
        super(GANModel, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def call(self, input_tensor, training=False, mask=None):
        x = self.generator(input_tensor)
        x = self.discriminator(x)
        return x


class DiscriminatorModel(BaseModel):
    def __init__(self, name="Critic"):
        super(DiscriminatorModel, self).__init__(name=name)
        self.l1 = layers.Conv2D(64, 2, activation=layers.ReLU())
        self.flat = layers.Flatten()
        self.dense = layers.Dense(1)

    def call(self, inputs, training=False, mask=None):
        x = self.l1(inputs, training=training)
        x = self.flat(x)
        x = self.dense(x, training=training)
        return x


class GeneratorModel(BaseModel):
    def __init__(self, name="Generator"):
        super(GeneratorModel, self).__init__(name=name)
        self.dense = layers.Dense(128, activation=layers.ReLU())
        self.reshape = layers.Reshape((7, 7, 128))
        self.out = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")

    def call(self, inputs, training=False, mask=None):
        x = self.dense(inputs, training=training)
        x = self.reshape(x)
        x = self.out(x, training=training)
        return x


g = GeneratorModel()
d = DiscriminatorModel()

plot_model(g.build_graph((7, 7, 1)), to_file="generator_model.png",
           expand_nested=True, show_shapes=True)

gan = GANModel(generator=g, discriminator=d)
plot_model(gan.build_graph((7, 7, 1)), to_file="gan_model.png", 
           expand_nested=True, show_shapes=True)

Edit

Using the functional Keras API I get the desired result (see here): the nested models are correctly resolved within the GAN model.

from tensorflow.keras import Model, layers, optimizers
from tensorflow.keras.utils import plot_model


def get_generator(input_dim):
    initial = layers.Input(shape=input_dim)

    x = layers.Dense(128, activation=layers.ReLU())(initial)
    x = layers.Reshape((7, 7, 128))(x)
    x = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")(x)

    return Model(inputs=initial, outputs=x, name="Generator")


def get_discriminator(input_dim):
    initial = layers.Input(shape=input_dim)

    x = layers.Conv2D(64, 2, activation=layers.ReLU())(initial)
    x = layers.Flatten()(x)
    x = layers.Dense(1)(x)

    return Model(inputs=initial, outputs=x, name="Discriminator")

def get_gan(input_dim, latent_dim):
    initial = layers.Input(shape=input_dim)

    x = get_generator(input_dim)(initial)
    x = get_discriminator(latent_dim)(x)

    return Model(inputs=initial, outputs=x, name="GAN")



m = get_generator((7, 7, 1))
m.compile(optimizer=optimizers.Adam())

plot_model(m, expand_nested=True, show_shapes=True, to_file="generator_model_functional.png")

gan = get_gan((7, 7, 1), (7, 7, 1))
plot_model(gan, expand_nested=True, show_shapes=True, to_file="gan_model_functional.png")

1 Answer


When you pass the generator and discriminator to GANModel, each acts as a single enclosed child layer that itself contains n layers. So if you plot the generator through the GANModel instance, it shows up as a single collapsed block, as in the plot below (the same goes for the discriminator), unlike the plots you get when using the models separately.

The reason is that when data flows through GANModel's call() method, the input implicitly passes through all the internal layers of the generator and discriminator, as dictated by their design. Here are two workarounds to get your desired plot.

[Plot: the GAN model with generator and discriminator collapsed into single blocks]
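This is easy to confirm: gan.layers contains only the two sub-models, so plot_model has nothing further to expand. A minimal check, assuming the classes defined in the question:

g = GeneratorModel()
d = DiscriminatorModel()
gan = GANModel(generator=g, discriminator=d)

# the GAN tracks only its two sub-models as layers, not their internals
print([layer.name for layer in gan.layers])  # expected: ['Generator', 'Critic']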


Option 1

You can probably guess the approach: in GANModel, we pass the input explicitly through each internal layer of the child models (generator, discriminator).

class GANModel(BaseModel):
    def __init__(self, generator, discriminator):
        super(GANModel, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def call(self, input_tensor, training=False, mask=None):
        x = input_tensor

        for gen_lyr in self.generator.layers:
            print(gen_lyr)  # check which generator layer is traversed
            x = gen_lyr(x)

        for disc_lyr in self.discriminator.layers:
            print(disc_lyr)  # check which discriminator layer is traversed
            x = disc_lyr(x)

        return x
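
For reference, the plot call reuses the build_graph helper from the question (the output file name here is just illustrative):

gan = GANModel(generator=g, discriminator=d)
plot_model(gan.build_graph((7, 7, 1)), to_file="gan_model_expanded.png",
           expand_nested=True, show_shapes=True)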

If you plot now, you will get the fully expanded graph. The print statements also list every internal layer of self.generator and self.discriminator:

# All Internal Layers of self.generator, self.discriminator
<tensorflow.python.keras.layers.core.Dense object at 0x7f2a472a3710>
<tensorflow.python.keras.layers.core.Reshape object at 0x7f2a461e8f50>
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7f2a44591f90>
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7f2a47317290>
<tensorflow.python.keras.layers.core.Flatten object at 0x7f2a47317ed0>
<tensorflow.python.keras.layers.core.Dense object at 0x7f2a57f42910>

[Plot: the GAN model with all internal generator and discriminator layers expanded]


Option 2

This is a somewhat uglier approach. First, we take every internal layer and build a Sequential model from them; then we use .build to create its input layer.

import tensorflow as tf

gan = GANModel(generator=g, discriminator=d)

# collect every internal layer of the generator and discriminator
all_layer = []
for layer in gan.layers:
    all_layer.extend(layer.layers)

# build a flat Sequential model from them and create its input layer
gan_plot = tf.keras.models.Sequential(all_layer)
gan_plot.build((None, 7, 7, 1))
list(all_layer)  # inspect the collected layers (output below)

[<tensorflow.python.keras.layers.core.Dense at 0x7f2a461ab390>,
 <tensorflow.python.keras.layers.core.Reshape at 0x7f2a46156110>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2a461fedd0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2a461500d0>,
 <tensorflow.python.keras.layers.core.Flatten at 0x7f2a4613ea10>,
 <tensorflow.python.keras.layers.core.Dense at 0x7f2a462cae10>]
tf.keras.utils.plot_model(gan_plot, expand_nested=True, show_shapes=True)
  • Thank you for your reply. I appreciate your efforts. I am still not convinced, as I lose the structure of the nested models. PS: I edited my question to show how smoothly it works with the functional API. – Molitoris Apr 01 '21 at 12:40
  • I see. Let me explain. In the **functional API**, we usually use `layers.Input`, which is basically a **spec layer** rather than a real **trainable layer**. In your edited part, you used it (**3** times) to build each block of layers, so `keras` will plot according to that **model definition**. – Innat Apr 01 '21 at 13:31
  • Now, going back to your **subclassing API**: you won't see any `layers.Input` except in one place, the `build_graph` method. The way you define your model in the `call` method is the way `keras` will plot it. – Innat Apr 01 '21 at 13:31
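
Following that observation, a sketch of a third option (not from the thread itself; it assumes g, d, and build_graph as defined in the question): convert each subclassed child into a functional Model via build_graph first, then compose those functional views. plot_model can then resolve the nesting exactly as in the functional-API version:

g_func = g.build_graph((7, 7, 1))                # functional view of the generator
d_func = d.build_graph(g_func.output_shape[1:])  # generator output shape: (7, 7, 1)

# compose the two functional views; expand_nested can now resolve them
inp = keras.Input(shape=(7, 7, 1))
gan_func = keras.Model(inputs=inp, outputs=d_func(g_func(inp)), name="GAN")
plot_model(gan_func, expand_nested=True, show_shapes=True,
           to_file="gan_model_nested.png")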