
I'm completely new to Keras. I've only been using it for a few days, so I'm quite inexperienced.

I was able to train a U-Net that segments a single class, using RGB images and grayscale masks as training input, with the following code:

import os

import cv2
import numpy as np

# ids_train_split (presumably a pandas Series of ids, given `.values`), batch_size
# and input_size are defined elsewhere. randomHueSaturationValue,
# randomShiftScaleRotate and randomHorizontalFlip are my own augmentation helpers.

def train_generator():
    while True:
        for start in range(0, len(ids_train_split), batch_size):
            x_batch = []
            y_batch = []
            end = min(start + batch_size, len(ids_train_split))
            ids_train_batch = ids_train_split[start:end]
            for img_id in ids_train_batch.values:
                # The image file name comes from the third '_'-separated field of the id
                img_name = 'IMG_' + str(img_id).split('_')[2]
                image_path = os.path.join("input", "train", "{}.JPG".format(img_name))
                mca_mask_path = os.path.join("input", "train_mask", "{}.png".format(img_id))

                img = cv2.imread(image_path)
                img = cv2.resize(img, (input_size, input_size))

                # Load into `mask`, the name the augmentation calls below expect
                mask = cv2.imread(mca_mask_path, cv2.IMREAD_GRAYSCALE)
                mask = cv2.resize(mask, (input_size, input_size))

                img = randomHueSaturationValue(img,
                                               hue_shift_limit=(-50, 50),
                                               sat_shift_limit=(-5, 5),
                                               val_shift_limit=(-15, 15))
                img, mask = randomShiftScaleRotate(img, mask,
                                                   shift_limit=(-0.0625, 0.0625),
                                                   scale_limit=(-0.1, 0.1),
                                                   rotate_limit=(0, 0))  # no rotation
                img, mask = randomHorizontalFlip(img, mask)
                mask = np.expand_dims(mask, axis=2)  # (H, W) -> (H, W, 1)
                x_batch.append(img)
                y_batch.append(mask)
            # Scale 0-255 pixel values to [0, 1]
            x_batch = np.array(x_batch, np.float32) / 255
            y_batch = np.array(y_batch, np.float32) / 255
            yield x_batch, y_batch
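Once the mask gains a second channel, every geometric augmentation has to move the image and all mask channels identically. Below is a minimal NumPy sketch of a joint horizontal flip, assuming (H, W, C) arrays. `random_horizontal_flip` is a hypothetical stand-in for the question's randomHorizontalFlip, whose code is not shown:

```python
import numpy as np


def random_horizontal_flip(img, mask, p=0.5, rng=None):
    """Flip an image and its (H, W, C) mask together with probability p.

    Works for any number of mask channels because only the width axis is flipped.
    """
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() < p:
        # axis=1 is the width axis for both (H, W, 3) images and (H, W, C) masks
        img = np.ascontiguousarray(np.flip(img, axis=1))
        mask = np.ascontiguousarray(np.flip(mask, axis=1))
    return img, mask
```

The `np.ascontiguousarray` copy is there because some OpenCV calls can fail on the negative-stride views that `np.flip` returns.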

And this is my U-Net model:

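from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, Input,
                                     MaxPooling2D, UpSampling2D, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop

# make_loss and dice_coef are my own helpers (not shown here).

def get_unet_1(pretrained_weights=None, input_shape=(1024, 1024, 3), num_classes=1, learning_rate=0.0001):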
    inputs = Input(shape=input_shape)
    # 1024

    down0b = Conv2D(8, (3, 3), padding='same')(inputs)
    down0b = BatchNormalization()(down0b)
    down0b = Activation('relu')(down0b)
    down0b = Conv2D(8, (3, 3), padding='same')(down0b)
    down0b = BatchNormalization()(down0b)
    down0b = Activation('relu')(down0b)
    down0b_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0b)
    # 512

    down0a = Conv2D(16, (3, 3), padding='same')(down0b_pool)
    down0a = BatchNormalization()(down0a)
    down0a = Activation('relu')(down0a)
    down0a = Conv2D(16, (3, 3), padding='same')(down0a)
    down0a = BatchNormalization()(down0a)
    down0a = Activation('relu')(down0a)
    down0a_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0a)
    # 256

    down0 = Conv2D(32, (3, 3), padding='same')(down0a_pool)
    down0 = BatchNormalization()(down0)
    down0 = Activation('relu')(down0)
    down0 = Conv2D(32, (3, 3), padding='same')(down0)
    down0 = BatchNormalization()(down0)
    down0 = Activation('relu')(down0)
    down0_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0)
    # 128

    down1 = Conv2D(64, (3, 3), padding='same')(down0_pool)
    down1 = BatchNormalization()(down1)
    down1 = Activation('relu')(down1)
    down1 = Conv2D(64, (3, 3), padding='same')(down1)
    down1 = BatchNormalization()(down1)
    down1 = Activation('relu')(down1)
    down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1)
    # 64

    down2 = Conv2D(128, (3, 3), padding='same')(down1_pool)
    down2 = BatchNormalization()(down2)
    down2 = Activation('relu')(down2)
    down2 = Conv2D(128, (3, 3), padding='same')(down2)
    down2 = BatchNormalization()(down2)
    down2 = Activation('relu')(down2)
    down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2)
    # 32

    down3 = Conv2D(256, (3, 3), padding='same')(down2_pool)
    down3 = BatchNormalization()(down3)
    down3 = Activation('relu')(down3)
    down3 = Conv2D(256, (3, 3), padding='same')(down3)
    down3 = BatchNormalization()(down3)
    down3 = Activation('relu')(down3)
    down3_pool = MaxPooling2D((2, 2), strides=(2, 2))(down3)
    # 16

    down4 = Conv2D(512, (3, 3), padding='same')(down3_pool)
    down4 = BatchNormalization()(down4)
    down4 = Activation('relu')(down4)
    down4 = Conv2D(512, (3, 3), padding='same')(down4)
    down4 = BatchNormalization()(down4)
    down4 = Activation('relu')(down4)
    down4_pool = MaxPooling2D((2, 2), strides=(2, 2))(down4)
    # 8

    center = Conv2D(1024, (3, 3), padding='same')(down4_pool)
    center = BatchNormalization()(center)
    center = Activation('relu')(center)
    center = Conv2D(1024, (3, 3), padding='same')(center)
    center = BatchNormalization()(center)
    center = Activation('relu')(center)
    # center

    up4 = UpSampling2D((2, 2))(center)
    up4 = concatenate([down4, up4], axis=3)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    # 16

    up3 = UpSampling2D((2, 2))(up4)
    up3 = concatenate([down3, up3], axis=3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    # 32

    up2 = UpSampling2D((2, 2))(up3)
    up2 = concatenate([down2, up2], axis=3)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    # 64

    up1 = UpSampling2D((2, 2))(up2)
    up1 = concatenate([down1, up1], axis=3)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    # 128

    up0 = UpSampling2D((2, 2))(up1)
    up0 = concatenate([down0, up0], axis=3)
    up0 = Conv2D(32, (3, 3), padding='same')(up0)
    up0 = BatchNormalization()(up0)
    up0 = Activation('relu')(up0)
    up0 = Conv2D(32, (3, 3), padding='same')(up0)
    up0 = BatchNormalization()(up0)
    up0 = Activation('relu')(up0)
    up0 = Conv2D(32, (3, 3), padding='same')(up0)
    up0 = BatchNormalization()(up0)
    up0 = Activation('relu')(up0)
    # 256

    up0a = UpSampling2D((2, 2))(up0)
    up0a = concatenate([down0a, up0a], axis=3)
    up0a = Conv2D(16, (3, 3), padding='same')(up0a)
    up0a = BatchNormalization()(up0a)
    up0a = Activation('relu')(up0a)
    up0a = Conv2D(16, (3, 3), padding='same')(up0a)
    up0a = BatchNormalization()(up0a)
    up0a = Activation('relu')(up0a)
    up0a = Conv2D(16, (3, 3), padding='same')(up0a)
    up0a = BatchNormalization()(up0a)
    up0a = Activation('relu')(up0a)
    # 512

    up0b = UpSampling2D((2, 2))(up0a)
    up0b = concatenate([down0b, up0b], axis=3)
    up0b = Conv2D(8, (3, 3), padding='same')(up0b)
    up0b = BatchNormalization()(up0b)
    up0b = Activation('relu')(up0b)
    up0b = Conv2D(8, (3, 3), padding='same')(up0b)
    up0b = BatchNormalization()(up0b)
    up0b = Activation('relu')(up0b)
    up0b = Conv2D(8, (3, 3), padding='same')(up0b)
    up0b = BatchNormalization()(up0b)
    up0b = Activation('relu')(up0b)
    # 1024

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(up0b)

    model = Model(inputs=inputs, outputs=classify)

    # `lr` is deprecated in tf.keras and removed in Keras 3; use `learning_rate`
    model.compile(optimizer=RMSprop(learning_rate=learning_rate), loss=make_loss('bce_dice'), metrics=[dice_coef, 'accuracy'])

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

Now I have to turn this into a multi-class problem, so instead of one mask I have two: two grayscale masks (Mca_mask and NotMca_mask) for the same training image. What is the standard practice in this case? Should I merge the two masks into one?
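One common approach is to keep the two masks as separate channels of a single target array instead of merging them into one grayscale image. Each label then has shape (H, W, 2), and the model's last layer uses `num_classes=2`. Below is a minimal NumPy sketch, assuming the masks are saved as 0/255 grayscale PNGs. The function name `stack_masks` and the threshold value are hypothetical:

```python
import numpy as np


def stack_masks(mca_mask, not_mca_mask, threshold=127):
    """Combine two single-channel uint8 masks into one (H, W, 2) float32 target.

    Channel 0 = Mca, channel 1 = NotMca. Pixels are binarised at `threshold`
    (assumption: masks are stored as 0/255 grayscale images).
    """
    if mca_mask.shape != not_mca_mask.shape:
        raise ValueError("masks must have the same height and width")
    stacked = np.stack([mca_mask, not_mca_mask], axis=-1)  # (H, W) x 2 -> (H, W, 2)
    return (stacked > threshold).astype(np.float32)
```

In the generator you would read both files with `cv2.imread(..., cv2.IMREAD_GRAYSCALE)`, resize them, and append `stack_masks(...)` to `y_batch` in place of the `expand_dims` step. With this target you would also drop the final `/ 255` on `y_batch`. Thresholding also removes the in-between values that `cv2.resize` produces with its default bilinear interpolation.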

Alex Kulinkovich
Mithosk93
  • As far as I understand from your code, you are doing "classification" with just one class, which in your case is the "mask", right? I think you are confusing image classification concepts. Could you explain what problem your model is intended to solve? – Hemerson Tacon Oct 18 '18 at 15:12
  • I was trying to figure out how to organize masks for a multi-class problem. From what I understand, my RGB mask images are organized like this: [1,1,1] -> dog, [2,2,2] -> cat, [0,0,0] -> unknown. I have to turn them into vectors like [1,0] = dog, [0,1] = cat, [0,0] = unknown. Is that right? Is there a quick way to do it, a command or something like that? I don't want to iterate over every pixel of every image. – Mithosk93 Oct 19 '18 at 09:19
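No per-pixel loop is needed: NumPy broadcasting can compare every pixel against all class values at once. Below is a minimal sketch, assuming the encoding from the comment above (1 -> dog, 2 -> cat, anything else -> unknown). The function name is hypothetical, and `class_values` can be reordered for the opposite encoding:

```python
import numpy as np


def rgb_label_mask_to_onehot(mask_rgb, class_values=(1, 2)):
    """Turn an (H, W, 3) mask whose pixels are [v, v, v] into (H, W, K) one-hot.

    Pixels whose value is not in class_values (e.g. 0 -> unknown) become all zeros.
    """
    labels = mask_rgb[..., 0]  # all three channels hold the same value
    # Broadcast (H, W, 1) against (K,) -> (H, W, K) boolean array, no Python loops
    onehot = labels[..., np.newaxis] == np.asarray(class_values)
    return onehot.astype(np.float32)
```

Taking channel 0 is safe even though `cv2.imread` returns BGR, because the three channels are identical in this encoding.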

1 Answer

On this line we can see that your output layer is applying sigmoid:

classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(up0b)

This means each output is squashed into the [0,1] range independently of the others. That is what you want for multi-label classification, where a pixel (or image) can belong to several classes at once. As a side note, another common way to map the output layer to [0,1] is softmax, which makes the outputs sum to 1, so as one class grows in confidence the others must decrease. Softmax suits mutually exclusive classes, but it is not a good fit for multi-label problems.
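The difference is easy to see on a single vector of raw scores. This small NumPy sketch uses made-up logits for three classes at one pixel:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))  # subtract max for numerical stability
    return e / e.sum(axis=-1, keepdims=True)


logits = np.array([3.0, 2.5, -1.0])  # made-up raw scores for three classes
print(sigmoid(logits))  # each value is independent; they need not sum to 1
print(softmax(logits))  # values sum to 1; raising one logit lowers the others
```

Raising the first logit leaves the sigmoid output for the second class unchanged but lowers its softmax probability.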

Your loss function is based on binary cross-entropy (the name `bce_dice` suggests BCE combined with a Dice term), as defined on this line:

model.compile(optimizer=RMSprop(learning_rate=learning_rate), loss=make_loss('bce_dice'), metrics=[dice_coef, 'accuracy'])

Binary cross-entropy treats each output channel as an independent yes/no decision, so it works for single-class and multi-label problems alike. It requires both the predictions and the targets to be in the [0,1] range.
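The question's `make_loss('bce_dice')` is not shown. The NumPy sketch below illustrates the usual BCE + Dice formulation on multi-channel masks. It is hypothetical: it assumes the two terms are simply added, and a real Keras loss would use TensorFlow ops instead of NumPy:

```python
import numpy as np


def bce_dice_loss(y_true, y_pred, eps=1e-7, smooth=1.0):
    """Mean per-pixel binary cross-entropy plus (1 - mean per-channel Dice).

    Shapes: (N, H, W, C); both arrays must hold values in [0, 1].
    """
    p = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    bce = -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    axes = (0, 1, 2)  # sum over batch and spatial dims, keep one Dice score per channel
    intersection = np.sum(y_true * y_pred, axis=axes)
    denom = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    dice = (2.0 * intersection + smooth) / (denom + smooth)
    return bce + (1.0 - np.mean(dice))
```

Each channel is scored independently, so adding a class only adds a channel. Targets outside [0,1] (such as a pixel value of 2) would make the BCE term meaningless.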

So you're already set up for multi-label classification. All you need to do is create multi-hot labels and set `num_classes` to the number of classes. For example, if your classes are dog, cat, bird, horse and goat, and an image contains a dog and a cat, its label would be [1, 1, 0, 0, 0], and you can train the network on that as-is.
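Building such a multi-hot label from the set of classes present is a one-liner. A minimal sketch, using the five example classes above (the function name is hypothetical):

```python
CLASSES = ["dog", "cat", "bird", "horse", "goat"]


def multi_hot(present, classes=CLASSES):
    """Return a 0/1 list with a 1 for every class present in the image."""
    unknown = set(present) - set(classes)
    if unknown:
        raise ValueError(f"unknown classes: {sorted(unknown)}")
    return [1 if c in present else 0 for c in classes]


print(multi_hot({"dog", "cat"}))  # → [1, 1, 0, 0, 0]
```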

David Parks
  • Thank you, I corrected the mistake! The problem in question has two classes, so I created the masks as RGB images where each pixel is [1, 1, 1] if it represents a cat and [2, 2, 2] if it represents a dog. Is there a faster, more concise way to create a vector of the form [0,1] for a dog and [1,0] for a cat? – Mithosk93 Oct 19 '18 at 09:10
  • This is not the optimal way to identify classes, and it's certainly not going to work with cross-entropy loss, which requires values in the `[0,1]` range. If you are doing pixel-level classification, your labels should not be an RGB image. Instead, each class should get its own 1-channel (grayscale-like) image, so the label has one channel per class. For example, with 5 classes the label is an image with 5 "color channels", and each channel holds a `[0,1]` value saying whether the pixel belongs to that class. If your output image is 32x32, your labels would then be 32x32x5 for 5 classes. – David Parks Oct 19 '18 at 17:52
  • Here's a nice looking tutorial on the subject that might help: https://medium.com/nanonets/how-to-do-image-segmentation-using-deep-learning-c673cc5862ef – David Parks Oct 19 '18 at 17:53
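If the per-pixel classes are stored as integer ids (0 to K-1) in a single-channel image, an identity-matrix lookup produces the H x W x K label described above, and `argmax` converts a prediction back to ids. A minimal NumPy sketch with hypothetical function names and random example data:

```python
import numpy as np


def index_mask_to_onehot(index_mask, num_classes):
    """(H, W) integer class ids -> (H, W, num_classes) float32 one-hot."""
    # Row i of the identity matrix is the one-hot vector for class i;
    # fancy indexing picks one row per pixel.
    return np.eye(num_classes, dtype=np.float32)[index_mask]


def onehot_to_index_mask(probs):
    """(H, W, C) per-class scores -> (H, W) id of the highest-scoring class."""
    return np.argmax(probs, axis=-1)


index_mask = np.random.default_rng(0).integers(0, 5, size=(32, 32))
labels = index_mask_to_onehot(index_mask, num_classes=5)
print(labels.shape)  # → (32, 32, 5)
```

The `argmax` decoding only makes sense when the classes are mutually exclusive per pixel. For multi-label outputs you would threshold each channel instead.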