import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        self.build(input_shape=[None, 1])

    def call(self, inputs, **kwargs):
        return self.dense(inputs)

MyModel().summary()


Plotting the model doesn't work either:

tf.keras.utils.plot_model(MyModel(), to_file='model_1.png', show_shapes=True)

I tried this code on several TensorFlow versions (2.3.0, 2.3.1, and 2.4.1), and every time the Output Shape column shows "multiple". Is this a bug? Is there a fix?

Innat
LearnToGrow

1 Answer


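It's not a bug. In general, Keras can't assume anything about the structure of a subclassed model: its forward pass is arbitrary Python code in call(), not a static graph of layers. That's why .summary() can't show per-layer output shapes for a subclassed model the way it does for models built with the Functional or Sequential API.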
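One way to see the difference is to check whether Keras recorded a symbolic output for a layer. The sketch below is a minimal illustration, not part of the original answer: `SubModel` mirrors the question's model, and `has_recorded_output` is a hypothetical helper, not a Keras API. It runs the subclassed model on real data and compares its inner `Dense` layer with a `Dense` layer used in a Functional model. As far as I can tell, tf.keras 2.x `summary()` prints "multiple" when it can't read a single output shape from a layer, which is the situation of the subclassed layer here.

```python
import tensorflow as tf


class SubModel(tf.keras.Model):
    """Same structure as MyModel in the question."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)


def has_recorded_output(layer):
    """Hypothetical helper: True if Keras recorded a symbolic output for `layer`."""
    try:
        layer.output  # raises if the layer has no recorded (symbolic) call
        return True
    except (AttributeError, ValueError):
        return False


# Subclassed: calling the model on real data runs call() eagerly,
# so no layer connection (and no symbolic output) is recorded for sub.dense.
sub = SubModel()
sub(tf.zeros((2, 1)))

# Functional: calling Dense on a symbolic Input records the connection.
inputs = tf.keras.Input(shape=(1,))
functional_dense = tf.keras.layers.Dense(1)
functional = tf.keras.Model(inputs=inputs, outputs=functional_dense(inputs))

print(has_recorded_output(sub.dense))         # False
print(has_recorded_output(functional_dense))  # True
```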

There is a workaround, though: call the model's call() method on a symbolic Input and wrap the result in a Functional model:

import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        self.build(input_shape=[None, 1])

    def call(self, inputs, **kwargs):
        return self.dense(inputs)

    def build_graph(self):
        # Run the forward pass on a symbolic Input and wrap the result in a
        # Functional model, which records the output shape of every layer.
        x = tf.keras.layers.Input(shape=(1,))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

MyModel().build_graph().summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 1)]               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 2         
=================================================================
Total params: 2
Trainable params: 2
Non-trainable params: 0
_________________________________________________________________
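The same wrapper should carry over to deeper models and other input sizes. In the sketch below, `DeepModel`, its `hidden_units` argument, and the `input_shape` parameter on `build_graph` are hypothetical additions for illustration, not part of the answer. Because the inner layers are called on a symbolic `Input`, each of them gets a concrete output shape that `summary()` can print.

```python
import tensorflow as tf


class DeepModel(tf.keras.Model):
    """Hypothetical two-layer subclassed model, used only for illustration."""

    def __init__(self, hidden_units=8):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(hidden_units, activation="relu")
        self.head = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.head(self.hidden(inputs))

    def build_graph(self, input_shape):
        # Hypothetical parameter: input_shape excludes the batch dimension, e.g. (4,)
        x = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=x, outputs=self.call(x))


model = DeepModel()
graph = model.build_graph(input_shape=(4,))
graph.summary()  # the Output Shape column lists (None, 8) and (None, 1)

print(tuple(model.hidden.output.shape))  # (None, 8)
print(tuple(graph.outputs[0].shape))     # (None, 1)
print(graph.count_params())              # 4*8 + 8 + 8*1 + 1 = 49
```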

The same wrapper works for plotting the model:

tf.keras.utils.plot_model(
    MyModel().build_graph(), to_file='model_1.png', show_shapes=True
)
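`plot_model` needs the `pydot` package and the Graphviz binaries. If those aren't available, a text-only alternative is to read the shapes straight off the wrapped Functional model. This is a minimal sketch: `layer_output_shapes` is a hypothetical helper (not a Keras API), and it skips the `InputLayer`. Layer names such as `dense` vary between runs, so only the shapes are stable.

```python
import tensorflow as tf


class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)

    def build_graph(self):
        x = tf.keras.Input(shape=(1,))
        return tf.keras.Model(inputs=x, outputs=self.call(x))


def layer_output_shapes(functional_model):
    """Hypothetical helper: map each non-input layer name to its output shape."""
    shapes = {}
    for layer in functional_model.layers:
        if isinstance(layer, tf.keras.layers.InputLayer):
            continue  # the input layer adds nothing beyond the Input shape
        shapes[layer.name] = tuple(layer.output.shape)
    return shapes


print(layer_output_shapes(MyModel().build_graph()))  # e.g. {'dense': (None, 1)}
```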
Innat