I am currently trying to implement my own function to do paired data augmentation in TensorFlow. To do this, I need to apply random image transformations to the input image and the corresponding transformations to the output masks (i.e. rotations and flipping). However, image flipping/reversing is not working as expected. A code example of what I have at the moment is as follows:

    import tensorflow as tf

    def paired_data_augmentation(image_masks_tuple):
        # image_masks_tuple contains an image as its first element
        # and masks as the remaining elements
        def flipper(image_masks, mode='lr'):
            # Reverse every tensor (image and masks) along the same axis.
            # Note: for [height, width, channels] tensors, axis 0 is the
            # height and axis 1 is the width.
            if mode == 'lr':
                image_masks = [
                    tf.reverse(m, [0]) for m in image_masks]
            elif mode == 'ud':
                image_masks = [
                    tf.reverse(m, [1]) for m in image_masks]
            image = image_masks[0]
            masks = image_masks[1:]
            return [image, *masks]

        image = image_masks_tuple[0]
        masks = image_masks_tuple[1:]
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # One random draw per flip type, meant to be shared by the image
        # and all of its masks
        flip_lr = tf.random_uniform([])
        flip_ud = tf.random_uniform([])
        tmp = [image, *masks]
        tmp = tf.cond(
            tf.greater_equal(flip_lr, 0.5),
            lambda: flipper(tmp, mode='lr'),
            lambda: tmp)
        set_shape(tmp)  # helper not shown in this snippet
        tmp = tf.cond(
            tf.greater_equal(flip_ud, 0.5),
            lambda: flipper(tmp, mode='ud'),
            lambda: tmp)
        set_shape(tmp)  # helper not shown in this snippet
        return tmp

What I was expecting this to do was to flip all images/masks in the same way if flip_lr >= 0.5 or flip_ud >= 0.5, but what actually happens is that some images are flipped without their masks being flipped, and vice versa. Has anyone experienced this and knows how to solve it?
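To make the expected behavior concrete, here is a minimal plain-Python sketch of what I mean by "flipped in the same way". It is not part of my TensorFlow pipeline: nested lists stand in for tensors, and the names `flip_lr`, `flip_ud` and `paired_flip` are hypothetical helpers used only for illustration. The point is that each flip decision is drawn once and then applied to the image and every mask alike.

```python
import random


def flip_lr(rows):
    """Mirror left-right by reversing every row."""
    return [row[::-1] for row in rows]


def flip_ud(rows):
    """Mirror up-down by reversing the order of the rows."""
    return rows[::-1]


def paired_flip(arrays, rng):
    """Draw one decision per flip type and apply it to every array."""
    do_lr = rng.random() >= 0.5
    do_ud = rng.random() >= 0.5
    flipped = []
    for array in arrays:
        if do_lr:
            array = flip_lr(array)
        if do_ud:
            array = flip_ud(array)
        flipped.append(array)
    return flipped, do_lr, do_ud


# The image and the mask start out identical, so they must remain
# identical after augmentation if both received the same flips.
image = [[1, 2], [3, 4]]
mask = [[1, 2], [3, 4]]
(flipped_image, flipped_mask), do_lr, do_ud = paired_flip(
    [image, mask], random.Random(0))
assert flipped_image == flipped_mask
```

In my TensorFlow version, the equivalent check (an image paired with a copy of itself as the mask) is what fails for some samples.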
Thanks in advance