
I use the following example to make my question clear:


import tensorflow as tf
from tensorflow import keras as K
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D


class Encoder(K.layers.Layer):
    def __init__(self, filters):
        super(Encoder, self).__init__()
        self.conv1 = Conv2D(filters=filters[0], kernel_size=3, strides=1, activation='relu', padding='same')
        self.conv2 = Conv2D(filters=filters[1], kernel_size=3, strides=1, activation='relu', padding='same')
        self.conv3 = Conv2D(filters=filters[2], kernel_size=3, strides=1, activation='relu', padding='same')
        self.pool = MaxPooling2D((2, 2), padding='same')

    def call(self, input_features):
        x = self.conv1(input_features)
        #print("Ex1", x.shape)
        x = self.pool(x)
        #print("Ex2", x.shape)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = self.pool(x)
        return x

class Decoder(K.layers.Layer):
    def __init__(self, filters):
        super(Decoder, self).__init__()
        self.conv1 = Conv2D(filters=filters[2], kernel_size=3, strides=1, activation='relu', padding='same')
        self.conv2 = Conv2D(filters=filters[1], kernel_size=3, strides=1, activation='relu', padding='same')
        self.conv3 = Conv2D(filters=filters[0], kernel_size=3, strides=1, activation='relu', padding='valid')
        self.conv4 = Conv2D(1, 3, 1, activation='sigmoid', padding='same')
        self.upsample = UpSampling2D((2, 2))
  
    def call(self, encoded):
        x = self.conv1(encoded)
        print("dx1", x.shape)
        x = self.upsample(x)
        #print("dx2", x.shape)
        x = self.conv2(x)
        x = self.upsample(x)
        x = self.conv3(x)
        x = self.upsample(x)
        return self.conv4(x)

class Autoencoder(K.Model):
    def __init__(self, filters):
        super(Autoencoder, self).__init__()
        self.loss = []
        self.encoder = Encoder(filters)
        self.decoder = Decoder(filters)

    def call(self, input_features):
        #print(input_features.shape)
        encoded = self.encoder(input_features)
        #print(encoded.shape)
        reconstructed = self.decoder(encoded)
        #print(reconstructed.shape)
        return reconstructed


max_epochs = 5
model = Autoencoder(filters)

model.compile(loss='binary_crossentropy', optimizer='adam')

loss = model.fit(x_train_noisy,
                x_train,
                validation_data=(x_test_noisy, x_test),
                epochs=max_epochs,
                batch_size=batch_size)

As you can see, the model is built from custom layers subclassed from keras.layers.Layer. If I then display the model's architecture with model.summary(), I get:

Model: "autoencoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 encoder (Encoder)           multiple                  14192     
                                                                 
 decoder (Decoder)           multiple                  16497     
                                                                 
=================================================================
Total params: 30,689
Trainable params: 30,689
Non-trainable params: 0

I would like a more detailed description of the encoder and decoder layers. Any ideas?

feelfree

1 Answer


You're getting this output because the model is built with the subclassing API. It's a known limitation: unlike the sequential or functional APIs, the subclassing API doesn't let model.summary() or plot_model show the nested structure as well as it could. There are two very relevant existing posts on this.

However, in your case, two changes will make summary and plot_model useful:

  1. Subclass keras.Model instead of keras.layers.Layer for the encoder and decoder sub-components.
  2. When you initialize layers in the __init__ method, define them in the same order in which they are used in the call method.

Encoder

Following points 1 and 2 above:

from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D


class Encoder(keras.Model):
    def __init__(self, filters):
        super().__init__(name='Encoder')
        self.conv1 = Conv2D(
            filters=filters[0], 
            kernel_size=3, 
            strides=1, 
            activation='relu', 
            padding='same'
        )
        self.pool1 = MaxPooling2D((2, 2), padding='same')
        self.conv2 = Conv2D(
            filters=filters[1], 
            kernel_size=3, 
            strides=1, 
            activation='relu', 
            padding='same'
        )
        self.pool2 = MaxPooling2D((2, 2), padding='same')
        self.conv3 = Conv2D(
            filters=filters[2], 
            kernel_size=3, 
            strides=1, 
            activation='relu', 
            padding='same'
        )
        self.pool3 = MaxPooling2D((2, 2), padding='same')
        
    def call(self, input_features):
        x = self.conv1(input_features)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.pool3(x)
        return x
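
As a quick sanity check, you can run a dummy batch through the encoder on its own and print its summary. The filter values and the 28x28x1 input below are purely illustrative assumptions (the question doesn't show them):

import tensorflow as tf

# Illustrative values only: the filters and the input shape are assumptions.
enc = Encoder(filters=[32, 32, 16])
_ = enc(tf.zeros((1, 28, 28, 1)))  # run a dummy batch so the weights get built
enc.summary()                      # per-layer breakdown of the encoder alone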

Decoder

Following points 1 and 2 above:

class Decoder(keras.Model):
    def __init__(self, filters):
        super().__init__(name='Decoder')
        self.conv1 = Conv2D(
            filters=filters[2], 
            kernel_size=3, 
            strides=1, 
            activation='relu', 
            padding='same'
        )
        self.upsample1 = UpSampling2D((2, 2))
        self.conv2 = Conv2D(
            filters=filters[1], 
            kernel_size=3, 
            strides=1, 
            activation='relu', 
            padding='same'
        )
        self.upsample2 = UpSampling2D((2, 2))
        self.conv3 = Conv2D(
            filters=filters[0],
            kernel_size=3, 
            strides=1, 
            activation='relu', 
            padding='valid'
        )
        self.upsample3 = UpSampling2D((2, 2))
        self.conv4 = Conv2D(1, 3, 1, activation='sigmoid', padding='same')
  
    def call(self, encoded):
        x = self.conv1(encoded)
        x = self.upsample1(x)
        x = self.conv2(x)
        x = self.upsample2(x)
        x = self.conv3(x)
        x = self.upsample3(x)
        return self.conv4(x)
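
Similarly, a quick shape check of the decoder alone. The 4x4x16 input below assumes the same illustrative filters=[32, 32, 16] and a 28x28 grayscale image fed through the encoder above:

import tensorflow as tf

# Assumed encoded size: a 28x28 input pooled three times by the encoder above.
dec = Decoder(filters=[32, 32, 16])
out = dec(tf.zeros((1, 4, 4, 16)))
print(out.shape)  # (1, 28, 28, 1): the 'valid' conv trims 16 -> 14 before the last upsampling
dec.summary()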

Autoencoder

Following points 1 and 2 above, we build the autoencoder as follows.

class Autoencoder(keras.Model):
    def __init__(self, filters):
        super().__init__(name='Autoencoder')
        self.encoder = Encoder(filters)
        self.decoder = Decoder(filters)

    def call(self, input_features):
        x = input_features
        
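        # Call every sub-layer of the encoder and decoder directly so that each
        # Conv2D / pooling / upsampling layer is built and reported individually
        # by summary(expand_nested=True) and plot_model(expand_nested=True).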
        for layer in self.encoder.layers:
            x = layer(x)
        
        for layer in self.decoder.layers: 
            x = layer(x)
        
        return x

Build Model


model = Autoencoder(filters)
model.build(input_shape=(1, 224, 224, 3))
model.summary(
    expand_nested=True, 
    line_length=80, show_trainable=True
)
Model: "Autoencoder"
___________________________________________________________________________________________
 Layer (type)                       Output Shape                    Param #     Trainable  
===========================================================================================
 Encoder (Encoder)                  multiple                        0 (unused)  Y          
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| conv2d_69 (Conv2D)               multiple                        3584        Y          |
|                                                                                         |
| max_pooling2d_21 (MaxPooling2D)  multiple                        0           Y          |
|                                                                                         |
| conv2d_70 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| max_pooling2d_22 (MaxPooling2D)  multiple                        0           Y          |
|                                                                                         |
| conv2d_71 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| max_pooling2d_23 (MaxPooling2D)  multiple                        0           Y          |
¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
 Decoder (Decoder)                  multiple                        0 (unused)  Y          
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| conv2d_72 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| up_sampling2d_9 (UpSampling2D)   multiple                        0           Y          |
|                                                                                         |
| conv2d_73 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| up_sampling2d_10 (UpSampling2D)  multiple                        0           Y          |
|                                                                                         |
| conv2d_74 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| up_sampling2d_11 (UpSampling2D)  multiple                        0           Y          |
|                                                                                         |
| conv2d_75 (Conv2D)               multiple                        1153        Y          |
¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
===========================================================================================
Total params: 742,657
Trainable params: 742,657
Non-trainable params: 0
___________________________________________________________________________________________

Great. But as you can see, in the summary, the Output Shape column is not informative. To fix it, we can use a class method (build_graph) as follows:


class Autoencoder(keras.Model):
    def __init__(self, filters):
        super().__init__(name='Autoencoder')
        self.encoder = Encoder(filters)
        self.decoder = Decoder(filters)

    def call(self, input_features):
        x = input_features
        
        for layer in self.encoder.layers:
            x = layer(x)
        
        for layer in self.decoder.layers: 
            x = layer(x)
        
        return x
    
    def build_graph(self, input_shape):
        x = keras.Input(shape=input_shape)
        return keras.Model(
            inputs=[x], outputs=self.call(x)
        )

Summary

model = Autoencoder(filters)  # re-instantiate so the new build_graph method is available
model.build_graph(
    input_shape=(224, 224, 3)
).summary(expand_nested=True)
# OK

keras.utils.plot_model(
    model.build_graph(input_shape=(224, 224, 3)), 
    expand_nested=True,
    show_shapes=True,
    show_dtype=True, 
    show_layer_activations=True, 
    show_layer_names=True
)
# OK
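
If you run this outside a notebook, plot_model can also write the diagram to disk via its to_file argument (this needs pydot and graphviz installed; the filename below is arbitrary):

keras.utils.plot_model(
    model.build_graph(input_shape=(224, 224, 3)),
    to_file='autoencoder.png',  # write the diagram to a file instead of displaying it
    expand_nested=True,
    show_shapes=True
)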

That's it. However, if you think this should be supported out of the box, feel free to open a ticket on the Keras GitHub repository.

Innat