I'm building a custom model (SegNet) in TensorFlow 2.1.0.

The first problem I'm facing is reusing the max-pooling indices, as described in the paper. Since SegNet is an encoder-decoder architecture, the pooling indices from the encoder are needed in the decoder to upsample the feature maps and place each value back at the position recorded by its index.

In TF these indices are not exposed by the tf.keras.layers.MaxPool2D layer (unlike in PyTorch, for example). To get them you have to use tf.nn.max_pool_with_argmax. That operation, however, returns the indices (argmax) in a flattened format, so further operations are needed before they are useful elsewhere in the network.
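
To make the flattened format concrete, here is a minimal sketch on a toy 4x4 single-channel input (the input and the variable names are illustrative, not taken from the model). It assumes the default include_batch_in_index=False, in which case each index encodes (row * width + col) * channels + channel of the input feature map.

```python
import numpy as np
import tensorflow as tf

# Toy input: one 4x4 image with a single channel, values 0..15 in row-major order.
x = tf.constant(np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1))

pooled, argmax = tf.nn.max_pool_with_argmax(
    x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding='SAME',
    output_dtype=tf.int64,
    include_batch_in_index=False)

# Decode the flat indices back into (row, col) positions of the input.
height, width, channels = 4, 4, 1
rows = argmax // (width * channels)
cols = (argmax // channels) % width

print(argmax.numpy().reshape(-1))  # flat positions of the four window maxima
print(rows.numpy().reshape(-1), cols.numpy().reshape(-1))
```

Because the index carries no batch component in this mode, whatever consumes it has to add the batch position itself.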

To implement a layer that performs max pooling and also returns these (flattened) indices, I defined a custom Keras layer.

import tensorflow as tf
from tensorflow.keras.layers import Layer


class MaxPoolingWithArgmax2D(Layer):

    def __init__(
            self,
            pool_size=(2, 2),
            strides=2,
            padding='same',
            **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.padding = padding
        self.pool_size = pool_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        padding = self.padding
        pool_size = self.pool_size
        strides = self.strides
        # argmax holds the flattened positions of the maxima in the input
        output, argmax = tf.nn.max_pool_with_argmax(
            inputs,
            ksize=pool_size,
            strides=strides,
            padding=padding.upper(),
            output_dtype=tf.int64)
        return output, argmax

This layer is used in the encoder, so a corresponding decoder layer is needed to perform the inverse operation (an UpSampling2D that uses the indices; the paper has further details).

After some research, I found legacy code (TF<2.1.0) and adapted it to perform the operation. However, I'm not fully convinced this code works well, and there are some things I don't like about it.

class MaxUnpooling2D(Layer):

    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.size = size

    def call(self, inputs, output_shape=None):
        updates, mask = inputs[0], inputs[1]
        with tf.name_scope(self.name):
            mask = tf.cast(mask, 'int32')
            #input_shape = tf.shape(updates, out_type='int32')
            input_shape = updates.get_shape()

            # This statement is required if I don't want to specify a batch size
            if input_shape[0] is None:
                batches = 1
            else:
                batches = input_shape[0]

            # calculate the new shape
            if output_shape is None:
                output_shape = (
                        batches,
                        input_shape[1]*self.size[0],
                        input_shape[2]*self.size[1],
                        input_shape[3])

            # calculate indices for batch, height, width and feature maps
            one_like_mask = tf.ones_like(mask, dtype='int32')
            batch_shape = tf.concat(
                    [[batches], [1], [1], [1]],
                    axis=0)
            batch_range = tf.reshape(
                    tf.range(output_shape[0], dtype='int32'),
                    shape=batch_shape)
            b = one_like_mask * batch_range
            y = mask // (output_shape[2] * output_shape[3])
            x = (mask // output_shape[3]) % output_shape[2]
            feature_range = tf.range(output_shape[3], dtype='int32')
            f = one_like_mask * feature_range

            # transpose indices & reshape update values to one dimension
            updates_size = tf.size(updates)
            indices = tf.transpose(tf.reshape(
                tf.stack([b, y, x, f]),
                [4, updates_size]))
            values = tf.reshape(updates, [updates_size])
            ret = tf.scatter_nd(indices, values, output_shape)
            return ret

The things that bother me are:

  1. Unflattening the indices (MaxUnpooling2D) depends on knowing a specific batch size, which I would like to leave as None (unspecified) for model validation.
  2. I am not sure this code is fully compatible with the rest of the library. During fit, tf.keras.metrics.MeanIoU settles at 0.341 and stays constant for every epoch after the first, whereas the standard accuracy metric works just fine.

Network architecture in depth

The complete definition of the model follows.

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from tensorflow.keras.layers import Layer


class SegNet:
    def __init__(self, data_shape, classes = 3, batch_size = None):
        self.MODEL_NAME = 'SegNet'
        self.MODEL_VERSION = '0.2'

        self.classes = classes
        self.batch_size = batch_size

        self.build_model(data_shape)

    def build_model(self, data_shape):
        input_shape = (data_shape, data_shape, 3)

        inputs = keras.Input(shape=input_shape, batch_size=self.batch_size, name='Input')

        # Build sequential model

        # Encoding
        encoders = 5
        feature_maps = [64, 128, 256, 512, 512]
        n_convolutions = [2, 2, 3, 3, 3]
        eb_input = inputs
        eb_argmax_indices = []
        for encoder_index in range(encoders):
            encoder_block, argmax_indices = self.encoder_block(
                eb_input, encoder_index, feature_maps[encoder_index], n_convolutions[encoder_index])
            eb_argmax_indices.append(argmax_indices)
            eb_input = encoder_block

        # Decoding
        decoders = encoders
        db_input = encoder_block
        eb_argmax_indices.reverse()
        feature_maps.reverse()
        n_convolutions.reverse()
        d_feature_maps = [512, 512, 256, 128, 64]
        d_n_convolutions = n_convolutions
        for decoder_index in range(decoders):
            decoder_block = self.decoder_block(
                db_input, eb_argmax_indices[decoder_index], decoder_index, d_feature_maps[decoder_index], d_n_convolutions[decoder_index])
            db_input = decoder_block

        output = layers.Softmax()(decoder_block)

        self.model = keras.Model(inputs=inputs, outputs=output, name="SegNet")

    def encoder_block(self, x, encoder_index, feature_maps, n_convolutions):
        bank_input = x
        for conv_index in range(n_convolutions):
            bank = self.eb_layers_bank(
                bank_input, conv_index, feature_maps, encoder_index)
            bank_input = bank

        max_pool, indices = MaxPoolingWithArgmax2D(pool_size=(
            2, 2), strides=2, padding='same', name='EB_{}_MPOOL'.format(encoder_index + 1))(bank)

        return max_pool, indices

    def eb_layers_bank(self, x, bank_index, feature_maps, encoder_index):

        bank_input = x

        conv_l = layers.Conv2D(feature_maps, (3, 3), padding='same', name='EB_{}_BANK_{}_CONV'.format(
            encoder_index + 1, bank_index + 1))(bank_input)
        batch_norm = layers.BatchNormalization(
            name='EB_{}_BANK_{}_BN'.format(encoder_index + 1, bank_index + 1))(conv_l)
        relu = layers.ReLU(name='EB_{}_BANK_{}_RL'.format(
            encoder_index + 1, bank_index + 1))(batch_norm)

        return relu

    def decoder_block(self, x, max_pooling_indices, decoder_index, feature_maps, n_convolutions):
        #bank_input = self.unpool_with_argmax(x, max_pooling_indices)
        bank_input = MaxUnpooling2D(name='DB_{}_UPSAMP'.format(decoder_index + 1))([x, max_pooling_indices])
        #bank_input = layers.UpSampling2D()(x)
        for conv_index in range(n_convolutions):
            # Flag the last convolution bank of this decoder block
            last_l_bank = conv_index == n_convolutions - 1
            bank = self.db_layers_bank(
                bank_input, conv_index, feature_maps, decoder_index, last_l_bank)
            bank_input = bank

        return bank

    def db_layers_bank(self, x, bank_index, feature_maps, decoder_index, last_l_bank):
        bank_input = x

        if last_l_bank and decoder_index == 4:
            conv_l = layers.Conv2D(self.classes, (1, 1), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                decoder_index + 1, bank_index + 1))(bank_input)
            #batch_norm = layers.BatchNormalization(
            #    name='DB_{}_BANK_{}_BN'.format(decoder_index + 1, bank_index + 1))(conv_l)
            return conv_l
        else:

            if last_l_bank and decoder_index > 0:
                conv_l = layers.Conv2D(int(feature_maps / 2), (3, 3), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                    decoder_index + 1, bank_index + 1))(bank_input)
            else:
                conv_l = layers.Conv2D(feature_maps, (3, 3), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                    decoder_index + 1, bank_index + 1))(bank_input)
            batch_norm = layers.BatchNormalization(
                name='DB_{}_BANK_{}_BN'.format(decoder_index + 1, bank_index + 1))(conv_l)
            relu = layers.ReLU(name='DB_{}_BANK_{}_RL'.format(
                decoder_index + 1, bank_index + 1))(batch_norm)

            return relu

    def get_model(self):
        return self.model

Here is the output of model.summary().

Model: "SegNet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input (InputLayer)              [(None, 416, 416, 3) 0                                            
__________________________________________________________________________________________________
EB_1_BANK_1_CONV (Conv2D)       (None, 416, 416, 64) 1792        Input[0][0]                      
__________________________________________________________________________________________________
EB_1_BANK_1_BN (BatchNormalizat (None, 416, 416, 64) 256         EB_1_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_1_BANK_1_RL (ReLU)           (None, 416, 416, 64) 0           EB_1_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_1_BANK_2_CONV (Conv2D)       (None, 416, 416, 64) 36928       EB_1_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_1_BANK_2_BN (BatchNormalizat (None, 416, 416, 64) 256         EB_1_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_1_BANK_2_RL (ReLU)           (None, 416, 416, 64) 0           EB_1_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_1_MPOOL (MaxPoolingWithArgma ((None, 208, 208, 64 0           EB_1_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_2_BANK_1_CONV (Conv2D)       (None, 208, 208, 128 73856       EB_1_MPOOL[0][0]                 
__________________________________________________________________________________________________
EB_2_BANK_1_BN (BatchNormalizat (None, 208, 208, 128 512         EB_2_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_2_BANK_1_RL (ReLU)           (None, 208, 208, 128 0           EB_2_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_2_BANK_2_CONV (Conv2D)       (None, 208, 208, 128 147584      EB_2_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_2_BANK_2_BN (BatchNormalizat (None, 208, 208, 128 512         EB_2_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_2_BANK_2_RL (ReLU)           (None, 208, 208, 128 0           EB_2_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_2_MPOOL (MaxPoolingWithArgma ((None, 104, 104, 12 0           EB_2_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_1_CONV (Conv2D)       (None, 104, 104, 256 295168      EB_2_MPOOL[0][0]                 
__________________________________________________________________________________________________
EB_3_BANK_1_BN (BatchNormalizat (None, 104, 104, 256 1024        EB_3_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_3_BANK_1_RL (ReLU)           (None, 104, 104, 256 0           EB_3_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_2_CONV (Conv2D)       (None, 104, 104, 256 590080      EB_3_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_2_BN (BatchNormalizat (None, 104, 104, 256 1024        EB_3_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_3_BANK_2_RL (ReLU)           (None, 104, 104, 256 0           EB_3_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_3_CONV (Conv2D)       (None, 104, 104, 256 590080      EB_3_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_3_BANK_3_BN (BatchNormalizat (None, 104, 104, 256 1024        EB_3_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
EB_3_BANK_3_RL (ReLU)           (None, 104, 104, 256 0           EB_3_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
EB_3_MPOOL (MaxPoolingWithArgma ((None, 52, 52, 256) 0           EB_3_BANK_3_RL[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_1_CONV (Conv2D)       (None, 52, 52, 512)  1180160     EB_3_MPOOL[0][0]                 
__________________________________________________________________________________________________
EB_4_BANK_1_BN (BatchNormalizat (None, 52, 52, 512)  2048        EB_4_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_4_BANK_1_RL (ReLU)           (None, 52, 52, 512)  0           EB_4_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_2_CONV (Conv2D)       (None, 52, 52, 512)  2359808     EB_4_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_2_BN (BatchNormalizat (None, 52, 52, 512)  2048        EB_4_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_4_BANK_2_RL (ReLU)           (None, 52, 52, 512)  0           EB_4_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_3_CONV (Conv2D)       (None, 52, 52, 512)  2359808     EB_4_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_4_BANK_3_BN (BatchNormalizat (None, 52, 52, 512)  2048        EB_4_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
EB_4_BANK_3_RL (ReLU)           (None, 52, 52, 512)  0           EB_4_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
EB_4_MPOOL (MaxPoolingWithArgma ((None, 26, 26, 512) 0           EB_4_BANK_3_RL[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_1_CONV (Conv2D)       (None, 26, 26, 512)  2359808     EB_4_MPOOL[0][0]                 
__________________________________________________________________________________________________
EB_5_BANK_1_BN (BatchNormalizat (None, 26, 26, 512)  2048        EB_5_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
EB_5_BANK_1_RL (ReLU)           (None, 26, 26, 512)  0           EB_5_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_2_CONV (Conv2D)       (None, 26, 26, 512)  2359808     EB_5_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_2_BN (BatchNormalizat (None, 26, 26, 512)  2048        EB_5_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
EB_5_BANK_2_RL (ReLU)           (None, 26, 26, 512)  0           EB_5_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_3_CONV (Conv2D)       (None, 26, 26, 512)  2359808     EB_5_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
EB_5_BANK_3_BN (BatchNormalizat (None, 26, 26, 512)  2048        EB_5_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
EB_5_BANK_3_RL (ReLU)           (None, 26, 26, 512)  0           EB_5_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
EB_5_MPOOL (MaxPoolingWithArgma ((None, 13, 13, 512) 0           EB_5_BANK_3_RL[0][0]             
__________________________________________________________________________________________________
DB_1_UPSAMP (MaxUnpooling2D)    (1, 26, 26, 512)     0           EB_5_MPOOL[0][0]                 
                                                                 EB_5_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_1_BANK_1_CONV (Conv2D)       (1, 26, 26, 512)     2359808     DB_1_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_1_BANK_1_BN (BatchNormalizat (1, 26, 26, 512)     2048        DB_1_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_1_BANK_1_RL (ReLU)           (1, 26, 26, 512)     0           DB_1_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_1_BANK_2_CONV (Conv2D)       (1, 26, 26, 512)     2359808     DB_1_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
DB_1_BANK_2_BN (BatchNormalizat (1, 26, 26, 512)     2048        DB_1_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
DB_1_BANK_2_RL (ReLU)           (1, 26, 26, 512)     0           DB_1_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
DB_1_BANK_3_CONV (Conv2D)       (1, 26, 26, 512)     2359808     DB_1_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
DB_1_BANK_3_BN (BatchNormalizat (1, 26, 26, 512)     2048        DB_1_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
DB_1_BANK_3_RL (ReLU)           (1, 26, 26, 512)     0           DB_1_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
DB_2_UPSAMP (MaxUnpooling2D)    (1, 52, 52, 512)     0           DB_1_BANK_3_RL[0][0]             
                                                                 EB_4_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_2_BANK_1_CONV (Conv2D)       (1, 52, 52, 512)     2359808     DB_2_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_2_BANK_1_BN (BatchNormalizat (1, 52, 52, 512)     2048        DB_2_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_2_BANK_1_RL (ReLU)           (1, 52, 52, 512)     0           DB_2_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_2_BANK_2_CONV (Conv2D)       (1, 52, 52, 512)     2359808     DB_2_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
DB_2_BANK_2_BN (BatchNormalizat (1, 52, 52, 512)     2048        DB_2_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
DB_2_BANK_2_RL (ReLU)           (1, 52, 52, 512)     0           DB_2_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
DB_2_BANK_3_CONV (Conv2D)       (1, 52, 52, 256)     1179904     DB_2_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
DB_2_BANK_3_BN (BatchNormalizat (1, 52, 52, 256)     1024        DB_2_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
DB_2_BANK_3_RL (ReLU)           (1, 52, 52, 256)     0           DB_2_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
DB_3_UPSAMP (MaxUnpooling2D)    (1, 104, 104, 256)   0           DB_2_BANK_3_RL[0][0]             
                                                                 EB_3_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_3_BANK_1_CONV (Conv2D)       (1, 104, 104, 256)   590080      DB_3_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_3_BANK_1_BN (BatchNormalizat (1, 104, 104, 256)   1024        DB_3_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_3_BANK_1_RL (ReLU)           (1, 104, 104, 256)   0           DB_3_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_3_BANK_2_CONV (Conv2D)       (1, 104, 104, 256)   590080      DB_3_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
DB_3_BANK_2_BN (BatchNormalizat (1, 104, 104, 256)   1024        DB_3_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
DB_3_BANK_2_RL (ReLU)           (1, 104, 104, 256)   0           DB_3_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
DB_3_BANK_3_CONV (Conv2D)       (1, 104, 104, 128)   295040      DB_3_BANK_2_RL[0][0]             
__________________________________________________________________________________________________
DB_3_BANK_3_BN (BatchNormalizat (1, 104, 104, 128)   512         DB_3_BANK_3_CONV[0][0]           
__________________________________________________________________________________________________
DB_3_BANK_3_RL (ReLU)           (1, 104, 104, 128)   0           DB_3_BANK_3_BN[0][0]             
__________________________________________________________________________________________________
DB_4_UPSAMP (MaxUnpooling2D)    (1, 208, 208, 128)   0           DB_3_BANK_3_RL[0][0]             
                                                                 EB_2_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_4_BANK_1_CONV (Conv2D)       (1, 208, 208, 128)   147584      DB_4_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_4_BANK_1_BN (BatchNormalizat (1, 208, 208, 128)   512         DB_4_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_4_BANK_1_RL (ReLU)           (1, 208, 208, 128)   0           DB_4_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_4_BANK_2_CONV (Conv2D)       (1, 208, 208, 64)    73792       DB_4_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
DB_4_BANK_2_BN (BatchNormalizat (1, 208, 208, 64)    256         DB_4_BANK_2_CONV[0][0]           
__________________________________________________________________________________________________
DB_4_BANK_2_RL (ReLU)           (1, 208, 208, 64)    0           DB_4_BANK_2_BN[0][0]             
__________________________________________________________________________________________________
DB_5_UPSAMP (MaxUnpooling2D)    (1, 416, 416, 64)    0           DB_4_BANK_2_RL[0][0]             
                                                                 EB_1_MPOOL[0][1]                 
__________________________________________________________________________________________________
DB_5_BANK_1_CONV (Conv2D)       (1, 416, 416, 64)    36928       DB_5_UPSAMP[0][0]                
__________________________________________________________________________________________________
DB_5_BANK_1_BN (BatchNormalizat (1, 416, 416, 64)    256         DB_5_BANK_1_CONV[0][0]           
__________________________________________________________________________________________________
DB_5_BANK_1_RL (ReLU)           (1, 416, 416, 64)    0           DB_5_BANK_1_BN[0][0]             
__________________________________________________________________________________________________
DB_5_BANK_2_CONV (Conv2D)       (1, 416, 416, 3)     195         DB_5_BANK_1_RL[0][0]             
__________________________________________________________________________________________________
softmax (Softmax)               (1, 416, 416, 3)     0           DB_5_BANK_2_CONV[0][0]           
==================================================================================================
Total params: 29,459,075
Trainable params: 29,443,203
Non-trainable params: 15,872
__________________________________________________________________________________________________

As you can see, I'm forced to specify a batch size in MaxUnpooling2D; otherwise I get errors saying the operation cannot be performed because the shapes contain None values and cannot be transformed correctly.

When I try to predict on an image, I have to use exactly that batch dimension; otherwise I get errors like:

InvalidArgumentError:  Shapes of all inputs must match: values[0].shape = [4,208,208,64] != values[1].shape = [1,208,208,64]
     [[{{node SegNet/DB_5_UPSAMP/PartitionedCall/PartitionedCall/DB_5_UPSAMP/stack}}]] [Op:__inference_predict_function_70839]

This is caused by the way the indices from the max pooling operation are unraveled.


Training graphs

Here is a reference training run of 20 epochs.

As you can see, the MeanIoU metric is flat: no progress and no updates after epoch 1.

[Figure: Mean intersection over union]

The other metric works fine, and the loss decreases correctly.

[Figure: Loss and accuracy]

––––––––––

Conclusions

  1. Is there a better way, more compatible with recent versions of TF, to implement the unraveling and upsampling with the indices from the max pooling operation?
  2. If the implementation is correct, why is the metric stuck at a specific value? Am I doing something wrong in the model?

Thank you!

rpasianotto

1 Answer

Custom layers can reshape with an unknown batch size in two ways.

If you know the rest of the shape, reshape using -1 as the batch size:

import tensorflow.keras.backend as K
reshaped = K.reshape(original, (-1, x, y, channels))

If you don't know the sizes in advance, use K.shape to get the shape as a tensor:

inputs_shape = K.shape(inputs)
batch_size = inputs_shape[:1]
x = inputs_shape[1:2]
y = inputs_shape[2:3]
ch = inputs_shape[3:]

# You can then concatenate these and operate on them (note that I kept them as 1D vectors, not scalars)
newShape = K.concatenate([batch_size, x, y, ch])  # apply your own operations to the parts first
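
Applied to the unpooling layer from the question, the dynamic-shape idea could look roughly like the sketch below. It is one possible variant under stated assumptions, not a drop-in confirmed by the question: DynamicMaxUnpooling2D is a hypothetical name, the indices are assumed to come from tf.nn.max_pool_with_argmax with include_batch_in_index=False, and the unpooled size is assumed to be exactly size times the pooled size (true for even input dimensions such as 416). Instead of stacking four index tensors, it adds a per-image offset to each flat index and scatters into a 1D buffer.

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class DynamicMaxUnpooling2D(Layer):
    """Hypothetical unpooling layer that works with an unknown batch size."""

    def __init__(self, size=(2, 2), **kwargs):
        super().__init__(**kwargs)
        self.size = size

    def call(self, inputs):
        updates, mask = inputs
        mask = tf.cast(mask, tf.int64)

        # tf.shape is resolved at run time, so the batch dimension can stay None
        # while the model is being built.
        in_shape = tf.shape(updates, out_type=tf.int64)
        batch = in_shape[0]
        out_h = in_shape[1] * self.size[0]
        out_w = in_shape[2] * self.size[1]
        channels = in_shape[3]
        per_image = out_h * out_w * channels

        # Assumption: mask holds per-image flat indices, so shift image b
        # by b * per_image to get an index into the whole flattened batch.
        offsets = tf.reshape(
            tf.range(batch, dtype=tf.int64) * per_image, [-1, 1, 1, 1])
        flat_indices = tf.reshape(mask + offsets, [-1, 1])
        flat_values = tf.reshape(updates, [-1])

        flat_out = tf.scatter_nd(
            flat_indices, flat_values, tf.reshape(batch * per_image, [1]))
        out = tf.reshape(flat_out, tf.stack([batch, out_h, out_w, channels]))

        # Restore the static shape information that later layers such as
        # Conv2D need (the channel dimension in particular).
        static = updates.shape
        out.set_shape([
            static[0],
            None if static[1] is None else static[1] * self.size[0],
            None if static[2] is None else static[2] * self.size[1],
            static[3]])
        return out
```

It would be called like the original layer, for example DynamicMaxUnpooling2D()([x, indices]). If this approach works for your case, the same layer instance should accept batches of any size at fit and predict time.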

When I wrote my own version of SegNet, I didn't use indices but kept a one-hot mask of the max positions instead. It takes extra operations, but it might work well:

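from tensorflow.keras.layers import Lambda, MaxPooling2D, UpSampling2D

def get_indices(tensors):
    # Lambda passes the list of inputs as a single argument
    original, unpooled = tensors
    is_equal = K.equal(original, unpooled)
    return K.cast(is_equal, K.floatx())

previous_output = ...
pooled = MaxPooling2D()(previous_output)
unpooled = UpSampling2D()(pooled)

one_hot_indices = Lambda(get_indices)([previous_output, unpooled])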

Then, after an upsampling, I concatenate these indices and apply a new convolution:

some_output = ...
upsampled = UpSampling2D()(some_output)
with_indices = Concatenate()([upsampled, one_hot_indices])
upsampled = Conv2D(...)(with_indices)
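
On the second question (the stuck MeanIoU), one possible cause that the question does not confirm is the metric's input format rather than the unpooling layer. tf.keras.metrics.MeanIoU builds a confusion matrix from integer class ids, so softmax probabilities passed as y_pred are not turned into class predictions for you. A minimal sketch of a workaround, assuming y_true holds integer class ids (not one-hot vectors) and using a hypothetical subclass name ArgmaxMeanIoU:

```python
import tensorflow as tf


class ArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
    """Hypothetical wrapper: converts per-class probabilities to class ids."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Assumption: y_pred has shape (..., num_classes) and y_true has shape (...)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)


metric = ArgmaxMeanIoU(num_classes=3)
y_true = tf.constant([[0, 1], [2, 1]])
y_pred = tf.constant([[[0.8, 0.1, 0.1], [0.1, 0.8, 0.1]],
                      [[0.1, 0.1, 0.8], [0.2, 0.7, 0.1]]])
metric.update_state(y_true, y_pred)
print(float(metric.result()))
```

If the labels are one-hot encoded instead, y_true would need the same argmax before being passed on.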
Daniel Möller