I am using a U-Net for an image segmentation problem. The network trains very well when I use the Dice loss, but it does not optimize for any order of magnitude of the learning rate when I use binary cross entropy or weighted cross entropy with logits. This is the model I am using:

import tensorflow as tf

def unet_no_dropout(pretrained_weights=None, input_size=(512, 512, 1), act='elu'):
    inputs = tf.keras.layers.Input(input_size)

    # Contracting path (encoder)
    conv1 = tf.keras.layers.Conv2D(64, 3, activation=act, padding='same', kernel_initializer='he_normal')(inputs)
    #conv1 = tf.keras.layers.Activation(swish)(conv1)
    conv1 = tf.keras.layers.Conv2D(64, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = tf.keras.layers.Conv2D(128, 3, activation=act, padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = tf.keras.layers.Conv2D(128, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = tf.keras.layers.Conv2D(256, 3, activation=act, padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = tf.keras.layers.Conv2D(256, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv3)
    pool3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = tf.keras.layers.Conv2D(512, 3, activation=act, padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = tf.keras.layers.Conv2D(512, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv4)
    #drop4 = tf.keras.layers.Dropout(0.5)(conv4)
    pool4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck
    conv5 = tf.keras.layers.Conv2D(1024, 3, activation=act, padding='same', kernel_initializer='he_normal')(pool4)
    conv5 = tf.keras.layers.Conv2D(1024, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv5)
    #drop5 = tf.keras.layers.Dropout(0.5)(conv5)

    # Expanding path (decoder) with skip connections
    up6 = tf.keras.layers.Conv2D(512, 2, activation=act, padding='same', kernel_initializer='he_normal')(
        tf.keras.layers.UpSampling2D(size=(2, 2))(conv5))
    merge6 = tf.keras.layers.concatenate([conv4, up6], axis=3)
    conv6 = tf.keras.layers.Conv2D(512, 3, activation=act, padding='same', kernel_initializer='he_normal')(merge6)
    conv6 = tf.keras.layers.Conv2D(512, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv6)

    up7 = tf.keras.layers.Conv2D(256, 2, activation=act, padding='same', kernel_initializer='he_normal')(
        tf.keras.layers.UpSampling2D(size=(2, 2))(conv6))
    merge7 = tf.keras.layers.concatenate([conv3, up7], axis=3)
    conv7 = tf.keras.layers.Conv2D(256, 3, activation=act, padding='same', kernel_initializer='he_normal')(merge7)
    conv7 = tf.keras.layers.Conv2D(256, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv7)

    up8 = tf.keras.layers.Conv2D(128, 2, activation=act, padding='same', kernel_initializer='he_normal')(
        tf.keras.layers.UpSampling2D(size=(2, 2))(conv7))
    merge8 = tf.keras.layers.concatenate([conv2, up8], axis=3)
    conv8 = tf.keras.layers.Conv2D(128, 3, activation=act, padding='same', kernel_initializer='he_normal')(merge8)
    conv8 = tf.keras.layers.Conv2D(128, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv8)

    up9 = tf.keras.layers.Conv2D(64, 2, activation=act, padding='same', kernel_initializer='he_normal')(
        tf.keras.layers.UpSampling2D(size=(2, 2))(conv8))
    merge9 = tf.keras.layers.concatenate([conv1, up9], axis=3)
    conv9 = tf.keras.layers.Conv2D(64, 3, activation=act, padding='same', kernel_initializer='he_normal')(merge9)
    conv9 = tf.keras.layers.Conv2D(64, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv9)
    conv9 = tf.keras.layers.Conv2D(2, 3, activation=act, padding='same', kernel_initializer='he_normal')(conv9)
    # Per-pixel foreground probability
    conv10 = tf.keras.layers.Conv2D(1, 1, activation='sigmoid')(conv9)

    model = tf.keras.Model(inputs=inputs, outputs=conv10)

    #model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4), loss='binary_crossentropy', metrics=[tf.keras.metrics.Accuracy()])
    #model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=5e-6), loss=combo_loss(alpha=1, beta=0.4), metrics=[dice_accuracy])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=8e-3), loss='binary_crossentropy', metrics=[dice_accuracy])
    #model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5), loss=combo_loss, metrics=[dice_accuracy])

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model
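The `dice_accuracy` metric and `combo_loss` referenced in the compile calls are not shown here. For orientation, a minimal sketch of a soft Dice coefficient metric of the usual form (the actual implementation may differ):

    def dice_accuracy(y_true, y_pred, smooth=1e-6):
        # Soft Dice coefficient: 2*|A ∩ B| / (|A| + |B|), with a small
        # smoothing term to avoid division by zero on empty masks.
        y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
        y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
        intersection = tf.reduce_sum(y_true_f * y_pred_f)
        return (2.0 * intersection + smooth) / (
            tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)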

I have also tried the ReLU activation function, and the result is the same.
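For reference, by "weighted cross entropy with logits" I mean a wrapper around `tf.nn.weighted_cross_entropy_with_logits` along these lines (a sketch; the exact `pos_weight` may differ). Note that this loss expects raw logits, so the final sigmoid activation has to be removed when using it:

    def weighted_bce_with_logits(pos_weight=5.0):
        # Wraps tf.nn.weighted_cross_entropy_with_logits as a Keras loss;
        # pos_weight > 1 up-weights the (rarer) foreground pixels.
        def loss(y_true, y_pred):
            y_true = tf.cast(y_true, tf.float32)
            return tf.reduce_mean(
                tf.nn.weighted_cross_entropy_with_logits(
                    labels=y_true, logits=y_pred, pos_weight=pos_weight))
        return loss

    # With this loss the last layer outputs logits instead of probabilities:
    # conv10 = tf.keras.layers.Conv2D(1, 1, activation=None)(conv9)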

What could be causing this?

  • The problem with `binary_crossentropy` training is that the background skews the results: if the background is large, the model gets high accuracy and low loss simply because the background is zero in both the ground truth and the segmentation, whatever the segmentation result is. Try using Intersection over Union (IoU) instead (a soft IoU sketch follows these comments). – Bilal Jan 04 '21 at 17:29
  • I think that is not the problem, because the accuracy does not change. I noticed that if I change the seed, it trains now. – Jan 04 '21 at 17:40
  • Training will look fine and you will get high accuracy during training, but when it comes to testing on new samples the results will not be satisfactory. – Bilal Jan 04 '21 at 18:35
  • I was always checking the validation accuracy as well, and the two are similar, so I don't think imbalance is the problem in this case. The main problem is that for some runs it does not train, while for others it does; for example, I changed the seed and it started training. I had the same issue with the Dice loss as well. – Jan 04 '21 at 19:52
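Following the IoU suggestion in the comments above, a minimal sketch of a soft (differentiable) IoU loss for this binary segmentation setup, assuming sigmoid probability outputs as in the model above (an illustration, not code from the question):

    def iou_loss(y_true, y_pred, smooth=1e-6):
        # Soft IoU on probabilities: intersection / union, flipped into a loss.
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        intersection = tf.reduce_sum(y_true * y_pred)
        union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
        return 1.0 - (intersection + smooth) / (union + smooth)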
