I'm building a custom model (SegNet) in TensorFlow 2.1.0.
The first problem I'm facing is reusing the max pooling indices, as described in the paper. Since SegNet is an encoder-decoder architecture, the pooling indices from the encoder are needed in the decoder to upsample the feature maps and place each value back at the position recorded by its index.
In TF, the tf.keras.layers.MaxPool2D layer does not expose these indices (PyTorch's MaxPool2d, for example, can return them with return_indices=True).
To get the indices of the max pooling operation, tf.nn.max_pool_with_argmax is required.
This operation, however, returns the indices (argmax) in a flattened format, which requires further operations before they are useful in other parts of the network.
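To make the flattened format concrete, here is a minimal sketch of the index arithmetic in plain Python. It assumes the default include_batch_in_index=False, for which the TensorFlow documentation gives the flattened index of a maximum at position [b, y, x, c] as (y * width + x) * channels + c, with width and channels taken from the input of the pooling. The helper names flatten_index and unravel_index are made up for illustration.

```python
def flatten_index(y, x, c, width, channels):
    """Per-image flattened index, as documented for
    tf.nn.max_pool_with_argmax with include_batch_in_index=False."""
    return (y * width + x) * channels + c


def unravel_index(flat, width, channels):
    """Inverse of flatten_index: recover (y, x, c) from a flattened index."""
    c = flat % channels
    x = (flat // channels) % width
    y = flat // (width * channels)
    return y, x, c


# Example: a 416x416 feature map with 64 channels (first encoder block).
flat = flatten_index(y=10, x=7, c=3, width=416, channels=64)
print(flat, unravel_index(flat, width=416, channels=64))
```

Note that the batch position is not encoded in this format, so it has to be handled separately when scattering the values back.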
To implement a layer that performs max pooling and also exports these (flattened) indices, I defined a custom Keras layer.
class MaxPoolingWithArgmax2D(Layer):
    def __init__(
            self,
            pool_size=(2, 2),
            strides=2,
            padding='same',
            **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.padding = padding
        self.pool_size = pool_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        padding = self.padding
        pool_size = self.pool_size
        strides = self.strides
        output, argmax = tf.nn.max_pool_with_argmax(
            inputs,
            ksize=pool_size,
            strides=strides,
            padding=padding.upper(),
            output_dtype=tf.int64)
        return output, argmax
This layer is used in the encoder, so a corresponding decoder layer is needed to perform the inverse operation (an UpSampling2D driven by the indices; the paper has further details on this operation).
After some research, I found legacy code (TF < 2.1.0) and adapted it to perform the operation. However, I'm not 100% convinced this code works well; in fact, there are some things I don't like about it.
class MaxUnpooling2D(Layer):
    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.size = size

    def call(self, inputs, output_shape=None):
        updates, mask = inputs[0], inputs[1]
        with tf.name_scope(self.name):
            mask = tf.cast(mask, 'int32')
            #input_shape = tf.shape(updates, out_type='int32')
            input_shape = updates.get_shape()

            # This statement is required if I don't want to specify a batch size
            if input_shape[0] is None:
                batches = 1
            else:
                batches = input_shape[0]

            # calculate the new shape
            if output_shape is None:
                output_shape = (
                    batches,
                    input_shape[1]*self.size[0],
                    input_shape[2]*self.size[1],
                    input_shape[3])

            # calculate indices for batch, height, width and feature maps
            one_like_mask = tf.ones_like(mask, dtype='int32')
            batch_shape = tf.concat(
                [[batches], [1], [1], [1]],
                axis=0)
            batch_range = tf.reshape(
                tf.range(output_shape[0], dtype='int32'),
                shape=batch_shape)
            b = one_like_mask * batch_range
            y = mask // (output_shape[2] * output_shape[3])
            x = (mask // output_shape[3]) % output_shape[2]
            feature_range = tf.range(output_shape[3], dtype='int32')
            f = one_like_mask * feature_range

            # transpose indices & reshape update values to one dimension
            updates_size = tf.size(updates)
            indices = tf.transpose(tf.reshape(
                tf.stack([b, y, x, f]),
                [4, updates_size]))
            values = tf.reshape(updates, [updates_size])
            ret = tf.scatter_nd(indices, values, output_shape)
            return ret
The things that bother me are:
- Unflattening the indices (in MaxUnpooling2D) is tied to knowing a specific batch size, which I would like to leave as None (unspecified) for model validation.
- I am not sure this code is actually 100% compatible with the rest of the library. In fact, during fit, if I use tf.keras.metrics.MeanIoU the value converges to 0.341 and stays constant for every epoch after the first. The standard accuracy metric, instead, works just fine.
Network architecture in Depth
The complete definition of the model follows.
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from tensorflow.keras.layers import Layer


class SegNet:
    def __init__(self, data_shape, classes = 3, batch_size = None):
        self.MODEL_NAME = 'SegNet'
        self.MODEL_VERSION = '0.2'
        self.classes = classes
        self.batch_size = batch_size
        self.build_model(data_shape)

    def build_model(self, data_shape):
        input_shape = (data_shape, data_shape, 3)
        inputs = keras.Input(shape=input_shape, batch_size=self.batch_size, name='Input')

        # Build sequential model
        # Encoding
        encoders = 5
        feature_maps = [64, 128, 256, 512, 512]
        n_convolutions = [2, 2, 3, 3, 3]
        eb_input = inputs
        eb_argmax_indices = []
        for encoder_index in range(encoders):
            encoder_block, argmax_indices = self.encoder_block(
                eb_input, encoder_index, feature_maps[encoder_index], n_convolutions[encoder_index])
            eb_argmax_indices.append(argmax_indices)
            eb_input = encoder_block

        # Decoding
        decoders = encoders
        db_input = encoder_block
        eb_argmax_indices.reverse()
        feature_maps.reverse()
        n_convolutions.reverse()
        d_feature_maps = [512, 512, 256, 128, 64]
        d_n_convolutions = n_convolutions
        for decoder_index in range(decoders):
            decoder_block = self.decoder_block(
                db_input, eb_argmax_indices[decoder_index], decoder_index, d_feature_maps[decoder_index], d_n_convolutions[decoder_index])
            db_input = decoder_block

        output = layers.Softmax()(decoder_block)
        self.model = keras.Model(inputs=inputs, outputs=output, name="SegNet")

    def encoder_block(self, x, encoder_index, feature_maps, n_convolutions):
        bank_input = x
        for conv_index in range(n_convolutions):
            bank = self.eb_layers_bank(
                bank_input, conv_index, feature_maps, encoder_index)
            bank_input = bank
        max_pool, indices = MaxPoolingWithArgmax2D(pool_size=(
            2, 2), strides=2, padding='same', name='EB_{}_MPOOL'.format(encoder_index + 1))(bank)
        return max_pool, indices

    def eb_layers_bank(self, x, bank_index, feature_maps, encoder_index):
        bank_input = x
        conv_l = layers.Conv2D(feature_maps, (3, 3), padding='same', name='EB_{}_BANK_{}_CONV'.format(
            encoder_index + 1, bank_index + 1))(bank_input)
        batch_norm = layers.BatchNormalization(
            name='EB_{}_BANK_{}_BN'.format(encoder_index + 1, bank_index + 1))(conv_l)
        relu = layers.ReLU(name='EB_{}_BANK_{}_RL'.format(
            encoder_index + 1, bank_index + 1))(batch_norm)
        return relu

    def decoder_block(self, x, max_pooling_idices, decoder_index, feature_maps, n_convolutions):
        #bank_input = self.unpool_with_argmax(x, max_pooling_idices)
        bank_input = MaxUnpooling2D(name='DB_{}_UPSAMP'.format(decoder_index + 1))([x, max_pooling_idices])
        #bank_input = layers.UpSampling2D()(x)
        for conv_index in range(n_convolutions):
            if conv_index == n_convolutions - 1:
                last_l_banck = True
            else:
                last_l_banck = False
            bank = self.db_layers_bank(
                bank_input, conv_index, feature_maps, decoder_index, last_l_banck)
            bank_input = bank
        return bank

    def db_layers_bank(self, x, bank_index, feature_maps, decoder_index, last_l_bank):
        bank_input = x
        if (last_l_bank) & (decoder_index == 4):
            conv_l = layers.Conv2D(self.classes, (1, 1), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                decoder_index + 1, bank_index + 1))(bank_input)
            #batch_norm = layers.BatchNormalization(
            #    name='DB_{}_BANK_{}_BN'.format(decoder_index + 1, bank_index + 1))(conv_l)
            return conv_l
        else:
            if (last_l_bank) & (decoder_index > 0):
                conv_l = layers.Conv2D(int(feature_maps / 2), (3, 3), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                    decoder_index + 1, bank_index + 1))(bank_input)
            else:
                conv_l = layers.Conv2D(feature_maps, (3, 3), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                    decoder_index + 1, bank_index + 1))(bank_input)
            batch_norm = layers.BatchNormalization(
                name='DB_{}_BANK_{}_BN'.format(decoder_index + 1, bank_index + 1))(conv_l)
            relu = layers.ReLU(name='DB_{}_BANK_{}_RL'.format(
                decoder_index + 1, bank_index + 1))(batch_norm)
            return relu

    def get_model(self):
        return self.model
Here is the output of model.summary().
Model: "SegNet"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Input (InputLayer) [(None, 416, 416, 3) 0
__________________________________________________________________________________________________
EB_1_BANK_1_CONV (Conv2D) (None, 416, 416, 64) 1792 Input[0][0]
__________________________________________________________________________________________________
EB_1_BANK_1_BN (BatchNormalizat (None, 416, 416, 64) 256 EB_1_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
EB_1_BANK_1_RL (ReLU) (None, 416, 416, 64) 0 EB_1_BANK_1_BN[0][0]
__________________________________________________________________________________________________
EB_1_BANK_2_CONV (Conv2D) (None, 416, 416, 64) 36928 EB_1_BANK_1_RL[0][0]
__________________________________________________________________________________________________
EB_1_BANK_2_BN (BatchNormalizat (None, 416, 416, 64) 256 EB_1_BANK_2_CONV[0][0]
__________________________________________________________________________________________________
EB_1_BANK_2_RL (ReLU) (None, 416, 416, 64) 0 EB_1_BANK_2_BN[0][0]
__________________________________________________________________________________________________
EB_1_MPOOL (MaxPoolingWithArgma ((None, 208, 208, 64 0 EB_1_BANK_2_RL[0][0]
__________________________________________________________________________________________________
EB_2_BANK_1_CONV (Conv2D) (None, 208, 208, 128 73856 EB_1_MPOOL[0][0]
__________________________________________________________________________________________________
EB_2_BANK_1_BN (BatchNormalizat (None, 208, 208, 128 512 EB_2_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
EB_2_BANK_1_RL (ReLU) (None, 208, 208, 128 0 EB_2_BANK_1_BN[0][0]
__________________________________________________________________________________________________
EB_2_BANK_2_CONV (Conv2D) (None, 208, 208, 128 147584 EB_2_BANK_1_RL[0][0]
__________________________________________________________________________________________________
EB_2_BANK_2_BN (BatchNormalizat (None, 208, 208, 128 512 EB_2_BANK_2_CONV[0][0]
__________________________________________________________________________________________________
EB_2_BANK_2_RL (ReLU) (None, 208, 208, 128 0 EB_2_BANK_2_BN[0][0]
__________________________________________________________________________________________________
EB_2_MPOOL (MaxPoolingWithArgma ((None, 104, 104, 12 0 EB_2_BANK_2_RL[0][0]
__________________________________________________________________________________________________
EB_3_BANK_1_CONV (Conv2D) (None, 104, 104, 256 295168 EB_2_MPOOL[0][0]
__________________________________________________________________________________________________
EB_3_BANK_1_BN (BatchNormalizat (None, 104, 104, 256 1024 EB_3_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
EB_3_BANK_1_RL (ReLU) (None, 104, 104, 256 0 EB_3_BANK_1_BN[0][0]
__________________________________________________________________________________________________
EB_3_BANK_2_CONV (Conv2D) (None, 104, 104, 256 590080 EB_3_BANK_1_RL[0][0]
__________________________________________________________________________________________________
EB_3_BANK_2_BN (BatchNormalizat (None, 104, 104, 256 1024 EB_3_BANK_2_CONV[0][0]
__________________________________________________________________________________________________
EB_3_BANK_2_RL (ReLU) (None, 104, 104, 256 0 EB_3_BANK_2_BN[0][0]
__________________________________________________________________________________________________
EB_3_BANK_3_CONV (Conv2D) (None, 104, 104, 256 590080 EB_3_BANK_2_RL[0][0]
__________________________________________________________________________________________________
EB_3_BANK_3_BN (BatchNormalizat (None, 104, 104, 256 1024 EB_3_BANK_3_CONV[0][0]
__________________________________________________________________________________________________
EB_3_BANK_3_RL (ReLU) (None, 104, 104, 256 0 EB_3_BANK_3_BN[0][0]
__________________________________________________________________________________________________
EB_3_MPOOL (MaxPoolingWithArgma ((None, 52, 52, 256) 0 EB_3_BANK_3_RL[0][0]
__________________________________________________________________________________________________
EB_4_BANK_1_CONV (Conv2D) (None, 52, 52, 512) 1180160 EB_3_MPOOL[0][0]
__________________________________________________________________________________________________
EB_4_BANK_1_BN (BatchNormalizat (None, 52, 52, 512) 2048 EB_4_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
EB_4_BANK_1_RL (ReLU) (None, 52, 52, 512) 0 EB_4_BANK_1_BN[0][0]
__________________________________________________________________________________________________
EB_4_BANK_2_CONV (Conv2D) (None, 52, 52, 512) 2359808 EB_4_BANK_1_RL[0][0]
__________________________________________________________________________________________________
EB_4_BANK_2_BN (BatchNormalizat (None, 52, 52, 512) 2048 EB_4_BANK_2_CONV[0][0]
__________________________________________________________________________________________________
EB_4_BANK_2_RL (ReLU) (None, 52, 52, 512) 0 EB_4_BANK_2_BN[0][0]
__________________________________________________________________________________________________
EB_4_BANK_3_CONV (Conv2D) (None, 52, 52, 512) 2359808 EB_4_BANK_2_RL[0][0]
__________________________________________________________________________________________________
EB_4_BANK_3_BN (BatchNormalizat (None, 52, 52, 512) 2048 EB_4_BANK_3_CONV[0][0]
__________________________________________________________________________________________________
EB_4_BANK_3_RL (ReLU) (None, 52, 52, 512) 0 EB_4_BANK_3_BN[0][0]
__________________________________________________________________________________________________
EB_4_MPOOL (MaxPoolingWithArgma ((None, 26, 26, 512) 0 EB_4_BANK_3_RL[0][0]
__________________________________________________________________________________________________
EB_5_BANK_1_CONV (Conv2D) (None, 26, 26, 512) 2359808 EB_4_MPOOL[0][0]
__________________________________________________________________________________________________
EB_5_BANK_1_BN (BatchNormalizat (None, 26, 26, 512) 2048 EB_5_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
EB_5_BANK_1_RL (ReLU) (None, 26, 26, 512) 0 EB_5_BANK_1_BN[0][0]
__________________________________________________________________________________________________
EB_5_BANK_2_CONV (Conv2D) (None, 26, 26, 512) 2359808 EB_5_BANK_1_RL[0][0]
__________________________________________________________________________________________________
EB_5_BANK_2_BN (BatchNormalizat (None, 26, 26, 512) 2048 EB_5_BANK_2_CONV[0][0]
__________________________________________________________________________________________________
EB_5_BANK_2_RL (ReLU) (None, 26, 26, 512) 0 EB_5_BANK_2_BN[0][0]
__________________________________________________________________________________________________
EB_5_BANK_3_CONV (Conv2D) (None, 26, 26, 512) 2359808 EB_5_BANK_2_RL[0][0]
__________________________________________________________________________________________________
EB_5_BANK_3_BN (BatchNormalizat (None, 26, 26, 512) 2048 EB_5_BANK_3_CONV[0][0]
__________________________________________________________________________________________________
EB_5_BANK_3_RL (ReLU) (None, 26, 26, 512) 0 EB_5_BANK_3_BN[0][0]
__________________________________________________________________________________________________
EB_5_MPOOL (MaxPoolingWithArgma ((None, 13, 13, 512) 0 EB_5_BANK_3_RL[0][0]
__________________________________________________________________________________________________
DB_1_UPSAMP (MaxUnpooling2D) (1, 26, 26, 512) 0 EB_5_MPOOL[0][0]
EB_5_MPOOL[0][1]
__________________________________________________________________________________________________
DB_1_BANK_1_CONV (Conv2D) (1, 26, 26, 512) 2359808 DB_1_UPSAMP[0][0]
__________________________________________________________________________________________________
DB_1_BANK_1_BN (BatchNormalizat (1, 26, 26, 512) 2048 DB_1_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
DB_1_BANK_1_RL (ReLU) (1, 26, 26, 512) 0 DB_1_BANK_1_BN[0][0]
__________________________________________________________________________________________________
DB_1_BANK_2_CONV (Conv2D) (1, 26, 26, 512) 2359808 DB_1_BANK_1_RL[0][0]
__________________________________________________________________________________________________
DB_1_BANK_2_BN (BatchNormalizat (1, 26, 26, 512) 2048 DB_1_BANK_2_CONV[0][0]
__________________________________________________________________________________________________
DB_1_BANK_2_RL (ReLU) (1, 26, 26, 512) 0 DB_1_BANK_2_BN[0][0]
__________________________________________________________________________________________________
DB_1_BANK_3_CONV (Conv2D) (1, 26, 26, 512) 2359808 DB_1_BANK_2_RL[0][0]
__________________________________________________________________________________________________
DB_1_BANK_3_BN (BatchNormalizat (1, 26, 26, 512) 2048 DB_1_BANK_3_CONV[0][0]
__________________________________________________________________________________________________
DB_1_BANK_3_RL (ReLU) (1, 26, 26, 512) 0 DB_1_BANK_3_BN[0][0]
__________________________________________________________________________________________________
DB_2_UPSAMP (MaxUnpooling2D) (1, 52, 52, 512) 0 DB_1_BANK_3_RL[0][0]
EB_4_MPOOL[0][1]
__________________________________________________________________________________________________
DB_2_BANK_1_CONV (Conv2D) (1, 52, 52, 512) 2359808 DB_2_UPSAMP[0][0]
__________________________________________________________________________________________________
DB_2_BANK_1_BN (BatchNormalizat (1, 52, 52, 512) 2048 DB_2_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
DB_2_BANK_1_RL (ReLU) (1, 52, 52, 512) 0 DB_2_BANK_1_BN[0][0]
__________________________________________________________________________________________________
DB_2_BANK_2_CONV (Conv2D) (1, 52, 52, 512) 2359808 DB_2_BANK_1_RL[0][0]
__________________________________________________________________________________________________
DB_2_BANK_2_BN (BatchNormalizat (1, 52, 52, 512) 2048 DB_2_BANK_2_CONV[0][0]
__________________________________________________________________________________________________
DB_2_BANK_2_RL (ReLU) (1, 52, 52, 512) 0 DB_2_BANK_2_BN[0][0]
__________________________________________________________________________________________________
DB_2_BANK_3_CONV (Conv2D) (1, 52, 52, 256) 1179904 DB_2_BANK_2_RL[0][0]
__________________________________________________________________________________________________
DB_2_BANK_3_BN (BatchNormalizat (1, 52, 52, 256) 1024 DB_2_BANK_3_CONV[0][0]
__________________________________________________________________________________________________
DB_2_BANK_3_RL (ReLU) (1, 52, 52, 256) 0 DB_2_BANK_3_BN[0][0]
__________________________________________________________________________________________________
DB_3_UPSAMP (MaxUnpooling2D) (1, 104, 104, 256) 0 DB_2_BANK_3_RL[0][0]
EB_3_MPOOL[0][1]
__________________________________________________________________________________________________
DB_3_BANK_1_CONV (Conv2D) (1, 104, 104, 256) 590080 DB_3_UPSAMP[0][0]
__________________________________________________________________________________________________
DB_3_BANK_1_BN (BatchNormalizat (1, 104, 104, 256) 1024 DB_3_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
DB_3_BANK_1_RL (ReLU) (1, 104, 104, 256) 0 DB_3_BANK_1_BN[0][0]
__________________________________________________________________________________________________
DB_3_BANK_2_CONV (Conv2D) (1, 104, 104, 256) 590080 DB_3_BANK_1_RL[0][0]
__________________________________________________________________________________________________
DB_3_BANK_2_BN (BatchNormalizat (1, 104, 104, 256) 1024 DB_3_BANK_2_CONV[0][0]
__________________________________________________________________________________________________
DB_3_BANK_2_RL (ReLU) (1, 104, 104, 256) 0 DB_3_BANK_2_BN[0][0]
__________________________________________________________________________________________________
DB_3_BANK_3_CONV (Conv2D) (1, 104, 104, 128) 295040 DB_3_BANK_2_RL[0][0]
__________________________________________________________________________________________________
DB_3_BANK_3_BN (BatchNormalizat (1, 104, 104, 128) 512 DB_3_BANK_3_CONV[0][0]
__________________________________________________________________________________________________
DB_3_BANK_3_RL (ReLU) (1, 104, 104, 128) 0 DB_3_BANK_3_BN[0][0]
__________________________________________________________________________________________________
DB_4_UPSAMP (MaxUnpooling2D) (1, 208, 208, 128) 0 DB_3_BANK_3_RL[0][0]
EB_2_MPOOL[0][1]
__________________________________________________________________________________________________
DB_4_BANK_1_CONV (Conv2D) (1, 208, 208, 128) 147584 DB_4_UPSAMP[0][0]
__________________________________________________________________________________________________
DB_4_BANK_1_BN (BatchNormalizat (1, 208, 208, 128) 512 DB_4_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
DB_4_BANK_1_RL (ReLU) (1, 208, 208, 128) 0 DB_4_BANK_1_BN[0][0]
__________________________________________________________________________________________________
DB_4_BANK_2_CONV (Conv2D) (1, 208, 208, 64) 73792 DB_4_BANK_1_RL[0][0]
__________________________________________________________________________________________________
DB_4_BANK_2_BN (BatchNormalizat (1, 208, 208, 64) 256 DB_4_BANK_2_CONV[0][0]
__________________________________________________________________________________________________
DB_4_BANK_2_RL (ReLU) (1, 208, 208, 64) 0 DB_4_BANK_2_BN[0][0]
__________________________________________________________________________________________________
DB_5_UPSAMP (MaxUnpooling2D) (1, 416, 416, 64) 0 DB_4_BANK_2_RL[0][0]
EB_1_MPOOL[0][1]
__________________________________________________________________________________________________
DB_5_BANK_1_CONV (Conv2D) (1, 416, 416, 64) 36928 DB_5_UPSAMP[0][0]
__________________________________________________________________________________________________
DB_5_BANK_1_BN (BatchNormalizat (1, 416, 416, 64) 256 DB_5_BANK_1_CONV[0][0]
__________________________________________________________________________________________________
DB_5_BANK_1_RL (ReLU) (1, 416, 416, 64) 0 DB_5_BANK_1_BN[0][0]
__________________________________________________________________________________________________
DB_5_BANK_2_CONV (Conv2D) (1, 416, 416, 3) 195 DB_5_BANK_1_RL[0][0]
__________________________________________________________________________________________________
softmax (Softmax) (1, 416, 416, 3) 0 DB_5_BANK_2_CONV[0][0]
==================================================================================================
Total params: 29,459,075
Trainable params: 29,443,203
Non-trainable params: 15,872
__________________________________________________________________________________________________
As you can see, I'm forced to specify a batch size in MaxUnpooling2D; otherwise I get errors saying the operation cannot be performed because there are None values and the shapes cannot be correctly transformed.
When I try to predict on an image, I'm forced to specify the correct batch dimension; otherwise I get errors like:
InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [4,208,208,64] != values[1].shape = [1,208,208,64]
[[{{node SegNet/DB_5_UPSAMP/PartitionedCall/PartitionedCall/DB_5_UPSAMP/stack}}]] [Op:__inference_predict_function_70839]
This is caused by the implementation required to unravel the indices from the max pooling operation.
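One way this dependency on a fixed batch size might be avoided is to read the shape at run time with tf.shape instead of updates.get_shape(), and to scatter into a flat buffer using a per-image offset. The following is a minimal sketch under stated assumptions, not a verified replacement: the class name DynamicMaxUnpooling2D is hypothetical, the indices are assumed to come from tf.nn.max_pool_with_argmax with include_batch_in_index=False, the pooling windows are assumed not to overlap (stride equal to pool size), and the spatial dimensions are assumed to divide evenly by the pool size (as 416 does through five 2x2 poolings).

```python
import tensorflow as tf


class DynamicMaxUnpooling2D(tf.keras.layers.Layer):
    """Hypothetical unpooling layer that does not need a static batch size."""

    def __init__(self, size=(2, 2), **kwargs):
        super().__init__(**kwargs)
        self.size = size

    def call(self, inputs):
        updates, mask = inputs[0], inputs[1]
        mask = tf.cast(mask, tf.int64)

        # Dynamic shape: the batch dimension is only known at run time.
        shape = tf.shape(updates, out_type=tf.int64)
        batch, channels = shape[0], shape[3]
        out_h = shape[1] * self.size[0]
        out_w = shape[2] * self.size[1]
        per_image = out_h * out_w * channels

        # The argmax indices are per image, so shift each image's indices
        # by its offset inside one flat buffer covering the whole batch.
        offsets = tf.reshape(
            tf.range(batch, dtype=tf.int64) * per_image, [-1, 1, 1, 1])
        flat_indices = tf.reshape(mask + offsets, [-1, 1])
        flat_values = tf.reshape(updates, [-1])

        flat_out = tf.scatter_nd(
            flat_indices, flat_values, tf.reshape(batch * per_image, [1]))
        out = tf.reshape(flat_out, tf.stack([batch, out_h, out_w, channels]))

        # Restore what is statically known so later layers (e.g. Conv2D)
        # can still see the channel count while the batch may stay None.
        static = updates.shape.as_list()
        out.set_shape([
            static[0],
            None if static[1] is None else static[1] * self.size[0],
            None if static[2] is None else static[2] * self.size[1],
            static[3]])
        return out
```

With a layer along these lines, the batch dimension never has to be hard-coded to 1, so the same graph should accept batches of any size at prediction time. This has not been verified against TF 2.1.0 specifically.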
Training graphs
Here is a reference training run over 20 epochs.
As you can see, the MeanIoU metric is flat: no progress and no updates after epoch 1.
The other metric works fine, and the loss decreases correctly.
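One thing that may be worth checking for the flat MeanIoU (an assumption, not a confirmed diagnosis): tf.keras.metrics.MeanIoU accumulates a confusion matrix from integer class labels, whereas the model above outputs per-class softmax probabilities. If the probabilities reach the metric unchanged, they are not compared as class indices. Below is a minimal sketch of a wrapper that converts the predictions first. It assumes y_true holds integer class indices of shape (batch, height, width), and the class name ArgmaxMeanIoU is hypothetical.

```python
import tensorflow as tf


class ArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
    """Hypothetical MeanIoU variant that accepts softmax outputs."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Turn (batch, height, width, classes) probabilities into
        # (batch, height, width) class indices before accumulating.
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)


# Tiny check on a 1x2 "image" with 3 classes.
metric = ArgmaxMeanIoU(num_classes=3)
y_true = tf.constant([[[0, 1]]])
y_pred = tf.constant([[[[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]]]])
metric.update_state(y_true, y_pred)
print(float(metric.result()))
```

If the labels are one-hot encoded instead, y_true would need the same argmax treatment before being handed to the parent class.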
––––––––––
Conclusions
- Is there a better way, more compatible with recent versions of TF, to implement the unraveling and upsampling with the indices from the max pooling operation?
- If the implementation is correct, why do I get a metric stuck at a specific value? Am I doing something wrong in the model?
Thank you!