
I am training a neural network to predict a binary mask on mouse brain images. For this I am augmenting my data with the ImageDataGenerator from keras.

But I have realized that ImageDataGenerator interpolates the data when applying spatial transformations.

This is fine for the image, but I certainly do not want my mask to contain non-binary values.
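The problem comes from resampling at fractional coordinates. With linear interpolation (assumed here to be what the augmentation uses), each output pixel blends its neighbouring source pixels, so a shift of only half a pixel already turns the edge of a 0/1 mask into 0.5. Nearest-neighbour resampling copies a single source pixel instead. The sketch below is plain NumPy, not Keras code. `shift_x` is a hypothetical helper that shifts a 2D mask horizontally with either linear (`order=1`) or nearest-neighbour (`order=0`) interpolation.

```python
import numpy as np


def shift_x(mask, dx, order=1):
    """Shift a 2D mask right by dx pixels (dx may be fractional).

    Hypothetical helper for illustration only. order=1 blends the two
    neighbouring source pixels (linear interpolation); order=0 copies the
    nearest source pixel. Pixels shifted in from outside are filled with 0.
    """
    w = mask.shape[1]
    src = np.arange(w) - dx                                 # source x-coordinate of each output column
    padded = np.pad(mask.astype(float), ((0, 0), (1, 1)))   # one zero column on each side
    if order == 0:
        idx = np.floor(src + 0.5).astype(int)               # nearest source column
        return padded[:, np.clip(idx + 1, 0, w + 1)]
    x0 = np.floor(src).astype(int)                          # left neighbour
    frac = src - x0                                         # distance past the left neighbour
    left = padded[:, np.clip(x0 + 1, 0, w + 1)]
    right = padded[:, np.clip(x0 + 2, 0, w + 1)]
    return (1 - frac) * left + frac * right


mask = np.zeros((4, 6))
mask[:, 2:4] = 1                                            # a 2-pixel-wide vertical bar

print(np.unique(shift_x(mask, 0.5, order=1)).tolist())      # [0.0, 0.5, 1.0]
print(np.unique(shift_x(mask, 0.5, order=0)).tolist())      # [0.0, 1.0]
```

Only the linear version introduces the fractional value 0.5 along the edges of the bar, which is the effect visible in the augmented mask.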

Is there any way to choose something like a nearest neighbor interpolation when applying the transformations? I have found no such option in the keras documentation.

To the left is the original binary mask, to the right is the augmented, interpolated mask


Code for the images:

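import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import preprocessing as kp  # assumption: kp is the Keras preprocessing module

# image: the binary mask as an array of shape (1, height, width, 1); flow() expects rank 4
data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,
                     height_shift_range=30,
                     shear_range=5,
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest')
image_datagen = kp.image.ImageDataGenerator(**data_gen_args)
image_generator = image_datagen.flow(image, seed=1)
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(np.squeeze(image))                       # original mask
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(np.squeeze(next(image_generator)[0]))    # augmented mask
plt.axis('off')
plt.savefig('vis/keras_example')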
Asked by Jimmy2027

1 Answer


I had the same problem with my own binary image data. There are several ways to approach this issue.

Simple answer: I solved it by manually converting the output of ImageDataGenerator to binary. If you iterate over the generator yourself (calling next() on it or using a for loop), you can simply use NumPy's where function to convert non-binary values to binary:

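import numpy as np

batch = next(image_generator)
# Threshold back to {0, 1}. batch > 0 keeps every pixel the mask touched (slightly
# enlarging it); batch > 0.5 is the midpoint of a 0/1 mask; any other threshold also works.
binary_images = np.where(batch > 0, 1, 0).astype(batch.dtype)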
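If you iterate with a for loop, the thresholding can be wrapped around the iterator once instead of being repeated at every call site. `binarize_batches` is a hypothetical helper, a minimal sketch assuming each batch is a NumPy array. Note that a Keras flow iterator keeps producing batches indefinitely, so you have to stop the loop yourself.

```python
import numpy as np


def binarize_batches(batches, threshold=0.5):
    """Yield every batch from `batches` thresholded to {0, 1}.

    Hypothetical helper: `batches` can be any iterable of NumPy arrays,
    e.g. the iterator returned by ImageDataGenerator.flow(...).
    The output keeps the dtype of the input batch.
    """
    for batch in batches:
        yield (batch > threshold).astype(batch.dtype)


# Usage with a stand-in for one augmented batch of masks
fake_batches = [np.array([[0.0, 0.2], [0.7, 1.0]], dtype=np.float32)]
for binary in binarize_batches(fake_batches):
    print(binary.tolist())   # [[0.0, 0.0], [1.0, 1.0]]
```

The default threshold of 0.5 is an assumption (the midpoint of a 0/1 mask); pass `threshold=0` to reproduce the `batch > 0` rule.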

Using the preprocessing_function argument in ImageDataGenerator

A better way is to use the preprocessing_function argument of ImageDataGenerator. As the documentation states, this custom function is applied to each input after it has been augmented, so you can add it to your data_gen_args as follows:

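import numpy as np
from keras.preprocessing.image import ImageDataGenerator

data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,
                     height_shift_range=30,
                     shear_range=5,
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest',
                     # applied to each sample after augmentation: maps it back to {0, 1}
                     preprocessing_function=lambda x: np.where(x > 0, 1, 0).astype(x.dtype))

# Use these arguments for the mask generator only; otherwise the input images are binarized too
mask_datagen = ImageDataGenerator(**data_gen_args)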
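A named function is easier to test than the lambda and makes the threshold explicit. `binarize_mask` is a hypothetical name. The sketch assumes each sample arrives as one rank-3 array of shape (height, width, channels) and must be returned with the same shape, which is what the documentation describes for preprocessing_function.

```python
import numpy as np


def binarize_mask(x, threshold=0.5):
    """Map one augmented sample back to {0, 1}, keeping its shape and dtype.

    Intended as ImageDataGenerator(preprocessing_function=binarize_mask) for
    the mask generator only (hypothetical name, not part of Keras).
    """
    return np.where(x > threshold, 1, 0).astype(x.dtype)


# A 2x2 single-channel sample with interpolated edge values
sample = np.array([[[0.0], [0.4]], [[0.6], [1.0]]], dtype=np.float32)
result = binarize_mask(sample)
print(result.shape, result.dtype)   # (2, 2, 1) float32
print(result[..., 0].tolist())      # [[0.0, 0.0], [1.0, 1.0]]
```

You would then pass `preprocessing_function=binarize_mask` in data_gen_args in place of the lambda.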

Note: in my experience, preprocessing_function is executed before rescale, which can also be passed as an argument to ImageDataGenerator in data_gen_args. This does not affect your case, but if you do set rescale, keep in mind that the threshold inside preprocessing_function is applied to the values before rescaling.
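To make the ordering concrete: if the masks are stored as 0/255 and you also pass rescale=1./255, a threshold inside preprocessing_function sees the 0..255 values, not 0..1. The sketch below imitates that order with plain NumPy (preprocessing first, then multiplication by rescale). It is not Keras's own code, and `standardize_like_keras` is a hypothetical name.

```python
import numpy as np


def standardize_like_keras(x, preprocessing_function=None, rescale=None):
    """Imitate the order described in the note: preprocessing_function, then rescale.

    Hypothetical stand-in for illustration; the real ImageDataGenerator.standardize
    can also apply centering/normalization, which is omitted here.
    """
    if preprocessing_function is not None:
        x = preprocessing_function(x)
    if rescale:
        x = x * rescale
    return x


mask_255 = np.array([0.0, 100.0, 200.0, 255.0])   # interpolated 0/255 mask values

# Threshold on the 0..255 scale, because rescale has not been applied yet
good = standardize_like_keras(mask_255, lambda x: np.where(x > 127.5, 255.0, 0.0), rescale=1 / 255)
print(np.allclose(good, [0, 0, 1, 1]))   # True

# A 0..1 threshold marks every nonzero pixel as foreground, then rescale shrinks 1 to 1/255
bad = standardize_like_keras(mask_255, lambda x: np.where(x > 0.5, 1.0, 0.0), rescale=1 / 255)
print(np.allclose(bad, [0, 0, 1, 1]))    # False
```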

Create a custom generator

Another solution is to write a custom data generator and modify the output of ImageDataGenerator inside it. Then use this new generator to feed model.fit(). Something like this:

import numpy as np
from tensorflow.keras import preprocessing as kp
from tensorflow.keras.utils import Sequence

# data_gen_args, image and model are defined as above
batch_size = 64
image_datagen = kp.image.ImageDataGenerator(**data_gen_args)
image_generator = image_datagen.flow(image, batch_size=batch_size, seed=1)

class MyImageDataGenerator(Sequence):
    def __init__(self, generator, data_size, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.generator = generator
        self.data_size = data_size
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(self.data_size / float(self.batch_size)))

    def __getitem__(self, idx):
        # idx is unused: the wrapped Keras iterator tracks its own position
        augmented_data = next(self.generator)
        # binarize; for supervised training, return an (inputs, targets) tuple instead
        return np.where(augmented_data > 0, 1, 0).astype(augmented_data.dtype)

my_image_generator = MyImageDataGenerator(image_generator, data_size=len(image), batch_size=batch_size)
model.fit(my_image_generator, epochs=50)

The data generator above is deliberately simple. If you need to, you can extend it to return your labels (like this), multimodal data, etc.
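For segmentation you usually need each augmented image together with its binarized mask. One common pattern, taken from the Keras documentation's example of transforming images and masks together rather than from this answer, is to build two flow iterators with the same data_gen_args and the same seed, so that they draw the same random transforms, and to binarize only the mask side. `paired_batches` is a hypothetical helper that accepts any two iterables of NumPy batches, so it can be tried without Keras.

```python
import numpy as np


def paired_batches(image_batches, mask_batches, threshold=0.5):
    """Yield (images, masks) tuples with the masks thresholded to {0, 1}.

    Hypothetical helper. With Keras, image_batches and mask_batches would be two
    ImageDataGenerator(...).flow(...) iterators built with the same arguments
    and the same seed, so each mask receives the same transform as its image.
    """
    for images, masks in zip(image_batches, mask_batches):
        yield images, (masks > threshold).astype(masks.dtype)


# Stand-in data: one batch of two 2x2 single-channel samples
images = [np.random.rand(2, 2, 2, 1).astype(np.float32)]
masks = [np.array([0.0, 0.3, 0.8, 1.0] * 2, dtype=np.float32).reshape(2, 2, 2, 1)]

for x, y in paired_batches(images, masks):
    print(x.shape, y.shape, np.unique(y).tolist())   # (2, 2, 2, 1) (2, 2, 2, 1) [0.0, 1.0]
```

Because the Keras iterators never stop on their own, such a generator would be passed to model.fit together with steps_per_epoch set to the number of batches per epoch (the value `__len__` computes above).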

Answer by Masoud Maleki, edited by Aelius