
I'm performing pixel-wise multi-class classification using a U-Net architecture in Keras (TensorFlow backend) on many 256x256 images. I've one-hot encoded my outputs using a data generator, so each output is a 256x256x32 array (I have 32 different classes, represented as integer pixel values from 0-31 in the 256x256 "mask" images).
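Roughly, the encoding step amounts to something like this (a simplified sketch rather than my actual generator code; keras.utils.to_categorical is just one way to do it):

import numpy as np
from keras.utils import to_categorical

mask = np.random.randint(0, 32, size=(256, 256))  # stand-in for one integer-valued 256x256 mask
one_hot = to_categorical(mask, num_classes=32)    # shape (256, 256, 32)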

However, most of the ground truth arrays are empty - in other words, the most common class by far is 0. When I train my U-Net, it seems to overfit to the 0 class: the loss is low and the accuracy very high, but only because ~99% of the ground truth pixels are 0, so the U-Net just outputs a bunch of 0s. I only really care about the other 31 classes - that is, how well it classifies the rest of the classes in the ground truth.

Is there a way to "weight" certain classes more than others when calculating the loss function (and if so, would this approach be appropriate)? I'm not sure if it's an intrinsic problem with my data, or a problem with my approach. Here's my U-Net:

import keras
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, SpatialDropout2D, Reshape, concatenate
from keras.optimizers import Adam

def unet(pretrained_weights = None, input_size = (256,256,1)):
    inputs = keras.engine.input_layer.Input(input_size)

    # Contracting path
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    #drop4 = Dropout(0.5)(conv4)
    drop4 = SpatialDropout2D(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    #drop5 = Dropout(0.5)(conv5)
    drop5 = SpatialDropout2D(0.5)(conv5)

    # Expanding path with skip connections
    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv9 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)

    # Per-pixel 32-class softmax, reshaped to (65536, 32) so each sample is a flat sequence of pixel predictions
    conv10 = Conv2D(32, 1, activation = 'softmax')(conv9)
    #conv10 = Flatten()(conv10)
    #conv10 = Dense(65536, activation = 'softmax')(conv10)
    flat10 = Reshape((65536,32))(conv10)
    #conv10 = Conv1D(1, 1, activation='linear')(conv10)

    model = Model(inputs = inputs, outputs = flat10)

    opt = Adam(lr=1e-6, clipvalue=0.01)
    model.compile(optimizer = opt, loss = 'categorical_crossentropy', metrics = ['categorical_accuracy'])
    #model.compile(optimizer = Adam(lr = 1e-6), loss = 'binary_crossentropy', metrics = ['accuracy'])
    #model.compile(optimizer = Adam(lr = 1e-4),

    #model.summary()

    if(pretrained_weights):
        model.load_weights(pretrained_weights)

    return model

Please let me know if more information is needed to diagnose the problem.

A. LaBella
  • You could perhaps try out a weighted categorical crossentropy function, and assign more weight to the rare classes? This could be a possible implementation: https://gist.github.com/wassname/ce364fddfc8a025bfab4348cf5de852d [a sketch along these lines appears after these comments] – sdcbr Jan 11 '19 at 16:42
  • Read about weighting classes and/or samples. A related question can be found here: https://stackoverflow.com/questions/43459317/keras-class-weight-vs-sample-weights-in-the-fit-generator and another one here: https://datascience.stackexchange.com/questions/13490/how-to-set-class-weights-for-imbalanced-classes-in-keras – Luke DeLuccia Jan 11 '19 at 16:43
  • Ah yes, I overlooked that, no custom loss function required :) – sdcbr Jan 11 '19 at 16:47
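For reference, a weighted categorical crossentropy along the lines of the gist linked in the first comment could look roughly like this (a minimal sketch, not the gist's exact code; weights is assumed to be a length-32 array with a small value for class 0 and larger values for the rare classes):

import numpy as np
from keras import backend as K

def weighted_categorical_crossentropy(weights):
    # weights: 1-D array of length num_classes (here 32)
    weights = K.variable(np.asarray(weights, dtype='float32'))

    def loss(y_true, y_pred):
        # make each pixel's class probabilities sum to 1 and avoid log(0)
        y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        # per-pixel cross-entropy, scaled by the weight of the true class
        return -K.mean(K.sum(y_true * K.log(y_pred) * weights, axis=-1))

    return loss

# usage: model.compile(optimizer = opt, loss = weighted_categorical_crossentropy(weights), metrics = ['categorical_accuracy'])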

1 Answer


A common solution for dealing with imbalanced classes is to weight some classes more heavily than others in the loss. In Keras this is easy to do with the optional class_weight argument during training:

model.fit(x, y, class_weight=class_weight)

You can either define the class weights yourself in a dict:

class_weight = {0: 1, 1: 100}

Or you can use scikit-learn's compute_class_weight function to generate the weights from your data automatically:

from sklearn.utils import class_weight
import numpy as np

class_weights = class_weight.compute_class_weight('balanced', np.unique(y), y)
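Keras expects a dict mapping class index to weight, so the resulting array can then be passed to fit, roughly like this (a sketch; here y in the compute_class_weight call is assumed to be a flat array of the integer class label of every training pixel, while x and y in fit are the training inputs and one-hot targets from the question):

# convert the array returned by compute_class_weight into the dict Keras expects
class_weight_dict = dict(enumerate(class_weights))   # {0: w0, 1: w1, ..., 31: w31}

model.fit(x, y, class_weight=class_weight_dict)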
Luke DeLuccia