I am trying to design a CNN that performs pixel-wise segmentation of cell images such as these:
With segmentation masks such as this (except that each raw image has more than one segmentation mask, e.g. interior of cell, border of cell, background):
I have mostly copied the U-Net design from here: https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/
However, even with 10 annotated images (over 300 cells), I still get quite bad Dice coefficient scores and poor predictions. According to the U-Net paper, this number of annotated cells should be sufficient for a good prediction.
This is the code for the model I am using.
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
from keras.optimizers import Adam

# img_rows, img_cols, window_size and f_num must be defined before calling get_unet();
# dice_coef and dice_coef_loss are defined further down.

def get_unet():
    inputs = Input((img_rows, img_cols, 1))

    # Contracting path: two convolutions, then 2x2 max pooling, at each level
    conv1 = Conv2D(16, window_size, activation='relu', padding='same')(inputs)
    conv1 = Conv2D(16, window_size, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(64, window_size, activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, window_size, activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(128, window_size, activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, window_size, activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(128, window_size, activation='relu', padding='same')(pool3)
    conv4 = Conv2D(128, window_size, activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck
    conv5 = Conv2D(512, window_size, activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, window_size, activation='relu', padding='same')(conv5)

    # Expanding path: upsample, concatenate with the matching encoder output, convolve twice
    up6 = concatenate([Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(128, window_size, activation='relu', padding='same')(up6)
    conv6 = Conv2D(128, window_size, activation='relu', padding='same')(conv6)
    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, window_size, activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, window_size, activation='relu', padding='same')(conv7)
    up8 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, window_size, activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, window_size, activation='relu', padding='same')(conv8)
    up9 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
    conv9 = Conv2D(16, window_size, activation='relu', padding='same')(up9)
    conv9 = Conv2D(16, window_size, activation='relu', padding='same')(conv9)

    # Per-pixel class probabilities, one output channel per class
    conv10 = Conv2D(f_num, (1, 1), activation='softmax')(conv9)  # change to N,(1,1) for more classes and softmax

    model = Model(inputs=[inputs], outputs=[conv10])
    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
    return model
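As a sanity check on the skip connections, here is a minimal pure-Python sketch that traces the spatial size and channel count at each concatenation. The helper `trace_unet_shapes` is hypothetical (it is not part of Keras). The 64x64 input size is taken from the model summary further down, and the filter counts are copied from the code above.

```python
def trace_unet_shapes(size, enc_filters, up_filters):
    """Return the (spatial size, channels) of each encoder skip and each concat."""
    skips = []
    for f in enc_filters:
        skips.append((size, f))   # output of the second conv at this level
        size //= 2                # 2x2 max pooling halves each spatial dimension

    concats = []
    for up_f, (skip_size, skip_f) in zip(up_filters, reversed(skips)):
        size *= 2                 # stride-2 transposed conv doubles it again
        assert size == skip_size, "skip and upsampled tensors must align"
        concats.append((size, up_f + skip_f))  # channels add along axis 3
    return skips, concats


skips, concats = trace_unet_shapes(
    64,
    enc_filters=[16, 64, 128, 128],   # conv1..conv4
    up_filters=[512, 128, 128, 64],   # Conv2DTranspose filters for up6..up9
)
for size, channels in concats:
    print(f"concat: {size}x{size}x{channels}")
```

The channel counts it prints should agree with the Concatenate rows of the model summary if the layers are wired as intended.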
I have tried many different hyperparameters for the model, all with no success. Dice scores hover around 0.25 and my loss barely decreases between epochs. I feel I am doing something fundamentally wrong here. Any suggestions?
EDIT: A sigmoid activation improves the Dice score from 0.25 to 0.33. However, this score is again reached after 1 epoch, and subsequent epochs improve it only very slightly (from 0.33 to 0.331, etc.).
dice_coef_loss is defined as follows:
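from keras import backend as K

smooth = 1.

def dice_coef(y_true, y_pred):
    # Flatten all pixels and channels so one Dice score is computed over the whole batch
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    # Negated so that minimising the loss maximises the Dice coefficient
    return -dice_coef(y_true, y_pred)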
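To get a feel for what the reported scores mean, here is a minimal NumPy sketch of the same formula. `dice_coef_np` is a hypothetical stand-in for the Keras version, and the random 64x64, 4-class one-hot target is made up purely for illustration. With this metric, a constant prediction of 0.25 in every channel scores about 0.25 and a constant 0.5 scores about 0.33. The scores mentioned above are therefore at least consistent with a network whose output is nearly constant, though this is an observation about the metric, not a confirmed diagnosis.

```python
import numpy as np


def dice_coef_np(y_true, y_pred, smooth=1.0):
    """NumPy version of the flattened, smoothed Dice coefficient."""
    y_true_f = np.asarray(y_true, dtype=float).ravel()
    y_pred_f = np.asarray(y_pred, dtype=float).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)


# Illustrative target: random integer labels turned into a one-hot (64, 64, 4) array.
rng = np.random.default_rng(0)
labels = rng.integers(0, 4, size=(64, 64))
target = np.eye(4)[labels]

perfect = dice_coef_np(target, target)
constant_quarter = dice_coef_np(target, np.full(target.shape, 0.25))  # uniform softmax-like output
constant_half = dice_coef_np(target, np.full(target.shape, 0.5))      # sigmoid output stuck at 0.5

print(f"perfect prediction: {perfect:.3f}")
print(f"constant 0.25:      {constant_quarter:.3f}")
print(f"constant 0.5:       {constant_half:.3f}")
```

Because the score is computed over all channels at once, it does not depend on which class each pixel belongs to when the prediction is constant, so any one-hot target of this shape gives the same two numbers.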
Also, in case it's useful, here is the model.summary() output:
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         (None, 64, 64, 1)         0
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 64, 64, 16)        32
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 64, 64, 16)        272
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 32, 32, 16)        0
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 32, 32, 64)        1088
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 32, 32, 64)        4160
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 16, 16, 64)        0
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 16, 16, 128)       8320
_________________________________________________________________
conv2d_25 (Conv2D)           (None, 16, 16, 128)       16512
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 8, 8, 128)         0
_________________________________________________________________
conv2d_26 (Conv2D)           (None, 8, 8, 128)         16512
_________________________________________________________________
conv2d_27 (Conv2D)           (None, 8, 8, 128)         16512
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 4, 4, 128)         0
_________________________________________________________________
conv2d_28 (Conv2D)           (None, 4, 4, 512)         66048
_________________________________________________________________
conv2d_29 (Conv2D)           (None, 4, 4, 512)         262656
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 8, 8, 512)         1049088
_________________________________________________________________
concatenate_5 (Concatenate)  (None, 8, 8, 640)         0
_________________________________________________________________
conv2d_30 (Conv2D)           (None, 8, 8, 128)         82048
_________________________________________________________________
conv2d_31 (Conv2D)           (None, 8, 8, 128)         16512
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 16, 16, 128)       65664
_________________________________________________________________
concatenate_6 (Concatenate)  (None, 16, 16, 256)       0
_________________________________________________________________
conv2d_32 (Conv2D)           (None, 16, 16, 128)       32896
_________________________________________________________________
conv2d_33 (Conv2D)           (None, 16, 16, 128)       16512
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 32, 32, 128)       65664
_________________________________________________________________
concatenate_7 (Concatenate)  (None, 32, 32, 192)       0
_________________________________________________________________
conv2d_34 (Conv2D)           (None, 32, 32, 64)        12352
_________________________________________________________________
conv2d_35 (Conv2D)           (None, 32, 32, 64)        4160
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 64, 64, 64)        16448
_________________________________________________________________
concatenate_8 (Concatenate)  (None, 64, 64, 80)        0
_________________________________________________________________
conv2d_36 (Conv2D)           (None, 64, 64, 16)        1296
_________________________________________________________________
conv2d_37 (Conv2D)           (None, 64, 64, 16)        272
_________________________________________________________________
conv2d_38 (Conv2D)           (None, 64, 64, 4)         68
=================================================================
Total params: 1,755,092.0
Trainable params: 1,755,092.0
Non-trainable params: 0.0
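One way to read the summary is to compare the reported parameter counts with what a given kernel size would produce. A standard 2D convolution with bias has kernel_h * kernel_w * in_channels * out_channels + out_channels parameters. The sketch below (the helper `conv2d_params` is hypothetical, not a Keras function) checks a few rows against 1x1 and 3x3 kernels. The value of window_size is not shown in the post, so treat the result as an inference from the numbers only: the counts match 1x1 kernels, whereas the U-Net paper uses 3x3 convolutions.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels):
    """Weights plus one bias per output channel for a standard 2D convolution."""
    return kernel_h * kernel_w * in_channels * out_channels + out_channels


# (in_channels, out_channels, reported Param #) copied from the summary above:
# conv2d_20, conv2d_21, conv2d_22 and conv2d_28.
reported = [(1, 16, 32), (16, 16, 272), (16, 64, 1088), (128, 512, 66048)]

for cin, cout, n in reported:
    match_1x1 = conv2d_params(1, 1, cin, cout) == n
    match_3x3 = conv2d_params(3, 3, cin, cout) == n
    print(f"{cin:>3} -> {cout:<3} reported={n:<6} 1x1={match_1x1} 3x3={match_3x3}")
```

With 1x1 kernels each output pixel only sees the corresponding input pixel (plus whatever context arrives through pooling and upsampling), so it may be worth printing window_size to confirm it is what was intended.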