I am using a weighted binary cross-entropy plus Dice loss for a segmentation problem with class imbalance (80 times more black pixels than white pixels).

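from keras import backend as K  # assumed: K is the Keras backend module

# weighted_bce_loss and dice_loss are helpers defined elsewhere (not shown here).
def weighted_bce_dice_loss(y_true, y_pred):
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    # Local average of the mask over a 50x50 window (near 0.5 at object boundaries)
    averaged_mask = K.pool2d(
        y_true, pool_size=(50, 50), strides=(1, 1), padding='same', pool_mode='avg')
    # w0: total weight of a uniform (all-ones) map, used for normalisation
    weight = K.ones_like(averaged_mask)
    w0 = K.sum(weight)
    # Per-pixel weight, largest where the local average is close to 0.5
    weight = 5. * K.exp(-5. * K.abs(averaged_mask - 0.5))
    w1 = K.sum(weight)
    # Rescale so the weights sum to w0
    weight *= (w0 / w1)
    loss = weighted_bce_loss(y_true, y_pred, weight) + dice_loss(y_true, y_pred)
    return loss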
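
To make the weighting in the function above easier to inspect, here is a minimal NumPy sketch of the same weight formula applied to a few hand-picked local averages. The helper name `boundary_weights` and the sample values are hypothetical and only for illustration; the pooling step is skipped and the averaged values are supplied directly.

```python
import numpy as np

def boundary_weights(averaged_mask):
    """NumPy mirror of the weight map computed in weighted_bce_dice_loss."""
    averaged_mask = np.asarray(averaged_mask, dtype=np.float64)
    w0 = float(averaged_mask.size)  # sum of an all-ones weight map
    weight = 5.0 * np.exp(-5.0 * np.abs(averaged_mask - 0.5))
    w1 = weight.sum()
    return weight * (w0 / w1)  # rescaled so the weights sum to w0

# Hypothetical local averages: pure black region, boundary, pure white region.
weights = boundary_weights([0.0, 0.5, 1.0])
print(weights)
```

Because the formula depends only on `|averaged_mask - 0.5|`, a fully black neighbourhood and a fully white neighbourhood receive the same weight. The map therefore emphasises boundaries rather than the rare white class, so it may not compensate for the 80:1 imbalance by itself.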

The Dice coefficient increased and the loss decreased, but at every epoch I am getting a black image as output (all the pixels are labelled black).
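
One thing that may be worth checking is whether the reported Dice coefficient is computed on raw probabilities ("soft" Dice) while the displayed image is thresholded. The sketch below assumes a common smoothed Dice formulation with `smooth=1.0`, which may differ from the `dice_loss` used here, and uses a made-up 90x90 mask with the same 80:1 ratio and made-up prediction values.

```python
import numpy as np

def dice(y_true, y_pred, smooth=1.0):
    """Smoothed Dice coefficient (assumed formulation, not taken from the post)."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth)

# Hypothetical mask: 100 white pixels, 8000 black pixels (80:1).
y_true = np.zeros((90, 90))
y_true[40:50, 40:50] = 1.0

# Hypothetical prediction: low everywhere, higher on the object but below 0.5.
y_pred = np.full((90, 90), 0.01)
y_pred[40:50, 40:50] = 0.4

soft_dice = dice(y_true, y_pred)               # computed on probabilities
hard_pred = (y_pred > 0.5).astype(np.float64)  # what a 0.5 threshold displays
hard_dice = dice(y_true, hard_pred)

print(soft_dice, hard_dice, hard_pred.sum())
```

In this made-up case the soft Dice is well above the thresholded Dice even though every pixel is displayed as black, so comparing the two values, or viewing the raw probability map, could help narrow down the cause.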

Any ideas on why this is happening?

AKSHAYAA VAIDYANATHAN

0 Answers