I'm trying to build a model using TensorFlow 2, so I created a model class as follows:

import tensorflow as tf

class Dummy(tf.keras.Model):
    def __init__(self, name="dummy"):
        super(Dummy, self).__init__()
        self._name = name

        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        return self.dense2(x)

model = Dummy()
model.build(input_shape=(None,5))

Now I want to plot the model. While summary() returns what I expect, plot_model(model, show_shapes=True, expand_nested=True) returns only a single block with the model name.

How can I plot the full graph of my model?

Timbus Calin
Hagai Tz.

1 Answer

Francois Chollet says the following:

You can do all these things (printing input / output shapes) in a Functional or Sequential model because these models are static graphs of layers.

In contrast, a subclassed model is a piece of Python code (a call method). There is no graph of layers here. We cannot know how layers are connected to each other (because that's defined in the body of call, not as an explicit data structure), so we cannot infer input / output shapes.
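To illustrate the contrast, here is a minimal sketch of the question's two Dense layers built with the Functional API (the variable name functional_dummy is only illustrative). Because each layer is called on a symbolic tf.keras.Input, Keras records the connections as an explicit graph, which is what summary() and plot_model() read:

```python
import tensorflow as tf

# The Dummy network from the question, rebuilt with the Functional API.
# Each layer is applied to a symbolic tensor, so Keras records how the
# layers are connected instead of hiding that inside a call() method.
inputs = tf.keras.Input(shape=(5,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
functional_dummy = tf.keras.Model(inputs=inputs, outputs=outputs, name="dummy")

# Input and output shapes are known without running any data through it.
functional_dummy.summary()
```

Passing functional_dummy to tf.keras.utils.plot_model (which requires pydot and Graphviz) should draw the input and both Dense layers rather than a single block.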

There are two solutions to this:

  1. Build your model with the Sequential or Functional API instead.
  2. Wrap your call method in a Functional model, like this:

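from tensorflow.keras import Input, Model

class Subclass(Model):

    def __init__(self):
        super(Subclass, self).__init__()
        ...  # create your layers here

    def call(self, x):
        ...  # forward pass through those layers; return the output tensor

    def model(self):
        # Trace call() on a symbolic Input to obtain a Functional model
        x = Input(shape=(24, 24, 3))
        return Model(inputs=[x], outputs=self.call(x))


if __name__ == '__main__':
    sub = Subclass()
    sub.model().summary()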
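Applied to the Dummy model from the question, the wrapping approach might look like the sketch below. build_graph is a hypothetical helper name chosen here, not a Keras API. It calls self.call on a symbolic Input with the (5,) feature shape the question passes to build(), and the resulting Functional model reuses the same Dense layer objects, so it shares their weights.

```python
import tensorflow as tf

class Dummy(tf.keras.Model):
    def __init__(self, name="dummy"):
        super(Dummy, self).__init__(name=name)
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        return self.dense2(x)

    # Hypothetical helper (not part of Keras): trace call() on a symbolic
    # Input so Keras can record the layer connectivity in a Functional model.
    def build_graph(self, input_shape=(5,)):
        x = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=x, outputs=self.call(x), name=self.name)

model = Dummy()
graph_model = model.build_graph()
graph_model.summary()
```

To get the diagram, pass the wrapped model instead of the subclassed one, e.g. tf.keras.utils.plot_model(model.build_graph(), show_shapes=True, expand_nested=True); as with any plot_model call, this assumes pydot and Graphviz are installed.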

Answer taken from: model.summary() can't print output shape while using subclass model

This article is also a good read: https://medium.com/tensorflow/what-are-symbolic-and-imperative-apis-in-tensorflow-2-0-dfccecb01021

Timbus Calin