My issue:
I am trying to train a semantic segmentation model in tf.keras. It works very well in channels_last (HWC) mode, where it reaches 96%+ validation accuracy. I want to train it in channels_first (CHW) mode so that the weights are compatible with TensorRT. When I do this, the training accuracy starts at ~80% in the first few epochs, then drops to around 0.020% and stays there permanently.
It is useful to know that the base of my model is a tf.keras.applications.MobileNet() model with the pre-trained 'imagenet' weights. (Model architecture at the bottom.)
The transformation process:
I followed the guidelines provided and changed only a few things:
- Set tf.keras.backend.set_image_data_format() to 'channels_first'.
- I changed the channel order in the input tensor from
  input_tensor=Input(shape=(376, 672, 3))
  to
  input_tensor=Input(shape=(3, 376, 672))
- In my image preprocessing (using tf.data.Dataset), I use tf.transpose(img, perm=[2, 0, 1]) on both the input image and the one-hot encoded mask to change the channel order. I checked this with an equality assertion to make sure it's correct, and it seems to be fine.
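For reference, the kind of check described in the last item can be sketched as follows. This is a NumPy stand-in for the tf.transpose call, not my actual pipeline: the tiny label map and the names labels, mask_hwc and mask_chw are made up for illustration.

```python
import numpy as np

# Hypothetical 2x4 label map with 3 classes
# (0 = lane, 1 = opposing lane, 2 = background).
labels = np.array([[0, 1, 2, 2],
                   [2, 2, 1, 0]])

# One-hot encode in channels_last (H, W, C) layout.
mask_hwc = np.eye(3, dtype=np.float32)[labels]    # shape (2, 4, 3)

# Same permutation as tf.transpose(img, perm=[2, 0, 1]).
mask_chw = np.transpose(mask_hwc, (2, 0, 1))      # shape (3, 2, 4)

# Every pixel must still be one-hot, now along axis 0 instead of axis -1,
# and decoding along axis 0 must give back the original labels.
assert mask_chw.shape == (3, 2, 4)
assert np.all(mask_chw.sum(axis=0) == 1.0)
assert np.array_equal(mask_chw.argmax(axis=0), labels)
```

If a check like this passes, the transpose itself keeps the encoding intact, and anything that reads the mask afterwards has to look at axis 0 (axis 1 once batched) rather than the last axis.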
With these changes the training starts fine but, as I said, the training accuracy then drops to almost zero. When I revert them, everything is fine again.
Possible leads:
What am I doing wrong or what could be the problematic part here? My suspicions are around these questions:
- Are the pre-trained ImageNet weights also converted to the 'channels_first' order when I set the backend? Is this something I should consider at all?
- Could it be that the tf.transpose() function messes up the mask's one-hot encoding? (I have 3 classes represented by 3 colors: lane, opposing lane, background.)
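On the first question: as far as I understand, Keras stores Conv2D kernels as (kernel_h, kernel_w, in_channels, out_channels) in both data formats, so only the activations would change layout, not the weights. The sketch below illustrates that idea in plain NumPy with a made-up 1x1 kernel. It is an assumption-laden toy, not a check of the actual MobileNet weights.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 1x1 convolution kernel in the (kh, kw, in, out) layout.
kernel = rng.normal(size=(1, 1, 3, 5))
w = kernel[0, 0]                                  # shape (in, out)

image_hwc = rng.normal(size=(4, 6, 3))            # channels_last input
image_chw = np.transpose(image_hwc, (2, 0, 1))    # channels_first input

# A 1x1 convolution is a per-pixel matrix multiply over the channel axis,
# so the same kernel can be applied to either layout.
out_hwc = np.einsum("hwc,co->hwo", image_hwc, w)  # shape (4, 6, 5)
out_chw = np.einsum("chw,co->ohw", image_chw, w)  # shape (5, 4, 6)

# The two results differ only by the layout of the output tensor.
same = np.allclose(np.transpose(out_hwc, (2, 0, 1)), out_chw)
```

Here same comes out True, which is what I would expect if the kernel layout is independent of the data format. Whether tf.keras.applications.MobileNet really behaves this way with 'imagenet' weights is exactly what I am unsure about.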
Maybe I am not seeing something obvious. I can provide further code and answers as needed.
EDIT:
08/17: This is still an ongoing issue. I have tried several things:
- I checked with a numpy assertion that the image and the mask are correct after the transpose; they seem correct.
- I suspected that the loss function calculates on the wrong axis, so I customized the loss function for the first axis (where the channels are). Here it is:
  def ReverseAxisLoss(y_true, y_pred):
      return K.categorical_crossentropy(y_true, y_pred, from_logits=True, axis=1)
- My main suspicion is that the 'channels first' backend setting does nothing to transpose the pretrained 'imagenet' weights for the mobilenet part. Is there an updated way for TF2.x / Keras to transpose the pre-trained weights into CHW format?
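To make the intent of the channel-axis loss from the list above concrete, here is a minimal NumPy sketch of categorical cross-entropy computed over axis 1 of an NCHW tensor. The function name channel_axis_crossentropy and the toy batch are hypothetical. The sketch assumes y_pred already holds probabilities, which may not match from_logits=True in my loss given that the head ends in a softmax activation; that mismatch is something I still need to double-check.

```python
import numpy as np

def channel_axis_crossentropy(y_true, y_pred, axis=1, eps=1e-7):
    """Categorical cross-entropy with the class dimension on `axis`.

    Assumes y_pred already contains probabilities summing to 1 along `axis`.
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.sum(y_true * np.log(y_pred), axis=axis)

# Hypothetical batch: 1 image, 3 classes, 2x2 pixels.
labels = np.array([[[0, 1],
                    [2, 2]]])                               # (N, H, W)
y_true = np.transpose(np.eye(3)[labels], (0, 3, 1, 2))      # (N, C, H, W)

# A prediction that puts 0.8 on the correct class and 0.1 on each other class.
y_pred = 0.1 + 0.7 * y_true

loss = channel_axis_crossentropy(y_true, y_pred, axis=1)    # (N, H, W)
```

With this toy input every pixel gets the same loss, -log(0.8), and the result has one value per pixel, which is the shape I would expect from a per-pixel segmentation loss in channels_first mode.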
Here is the architecture that I use (skipNet() is the head network, the MobileNet is the base, and they are connected in the create_model() function):
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import (Input, Conv2D, Conv2DTranspose,
                                     BatchNormalization, ZeroPadding2D, add)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

def skipNet(encoder_output, feed1, feed2, classes):
    # random initializer and regularizer
    stddev = 0.01
    init = RandomNormal(stddev=stddev)
    weight_decay = 1e-3
    reg = l2(weight_decay)

    # 1x1 convolutions that score the two skip connections
    score_feed2 = Conv2D(kernel_size=(1, 1), filters=classes, padding="SAME",
                         kernel_initializer=init, kernel_regularizer=reg)(feed2)
    score_feed2_bn = BatchNormalization()(score_feed2)
    score_feed1 = Conv2D(kernel_size=(1, 1), filters=classes, padding="SAME",
                         kernel_initializer=init, kernel_regularizer=reg)(feed1)
    score_feed1_bn = BatchNormalization()(score_feed1)

    # upsample the encoder output 2x, pad the height, and fuse with feed1
    upscore2 = Conv2DTranspose(kernel_size=(4, 4), filters=classes, strides=(2, 2),
                               padding="SAME", kernel_initializer=init,
                               kernel_regularizer=reg)(encoder_output)
    height_pad1 = ZeroPadding2D(padding=((1, 0), (0, 0)))(upscore2)
    upscore2_bn = BatchNormalization()(height_pad1)
    fuse_feed1 = add([score_feed1_bn, upscore2_bn])

    # upsample another 2x, pad the height, and fuse with feed2
    upscore4 = Conv2DTranspose(kernel_size=(4, 4), filters=classes, strides=(2, 2),
                               padding="SAME", kernel_initializer=init,
                               kernel_regularizer=reg)(fuse_feed1)
    height_pad2 = ZeroPadding2D(padding=((0, 1), (0, 0)))(upscore4)
    upscore4_bn = BatchNormalization()(height_pad2)
    fuse_feed2 = add([score_feed2_bn, upscore4_bn])

    # final 8x upsampling with a softmax activation
    upscore8 = Conv2DTranspose(kernel_size=(16, 16), filters=classes, strides=(8, 8),
                               padding="SAME", kernel_initializer=init,
                               kernel_regularizer=reg, activation="softmax")(fuse_feed2)
    return upscore8

def create_model(classes):
    # IMG_SHAPE is defined elsewhere, e.g. (3, 376, 672) in channels_first mode
    base_model = tf.keras.applications.MobileNet(input_tensor=Input(shape=IMG_SHAPE),
                                                 include_top=False,
                                                 weights='imagenet')

    # intermediate feature maps used as skip connections
    conv4_2_output = base_model.get_layer(index=43).output
    conv3_2_output = base_model.get_layer(index=30).output
    conv_score_output = base_model.output
    head_model = skipNet(conv_score_output, conv4_2_output, conv3_2_output, classes)

    # freeze the pre-trained base so only the head is trained
    for layer in base_model.layers:
        layer.trainable = False

    model = Model(inputs=base_model.input, outputs=head_model)
    return model
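
One more thing I may need to rule out in the head above: as far as I know, BatchNormalization() defaults to axis=-1 and the "softmax" activation normalizes over the last axis, which in channels_first mode is the width, not the class dimension. I have not confirmed that this is the cause. The NumPy sketch below only shows how much the two axis choices differ on a made-up NCHW tensor; the softmax helper and the scores tensor are hypothetical.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)

# Hypothetical NCHW scores: 1 image, 3 classes, 4x6 pixels.
scores = rng.normal(size=(1, 3, 4, 6))

over_width = softmax(scores, axis=-1)     # what a last-axis softmax computes
over_classes = softmax(scores, axis=1)    # per-pixel class probabilities

# Only the channel-axis version gives class probabilities summing to 1 per pixel.
class_sums_ok = np.allclose(over_classes.sum(axis=1), 1.0)
width_version_ok = np.allclose(over_width.sum(axis=1), 1.0)
```

Here class_sums_ok is True and width_version_ok is False. If the same thing happens inside my model, passing axis=1 to BatchNormalization and replacing the activation argument with a separate tf.keras.layers.Softmax(axis=1) layer might be worth trying, but that is a guess on my part.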