20

TensorFlow has a number of image operations for distorting input images during training, e.g. tf.image.random_flip_left_right(image, seed=None), tf.image.random_brightness(image, max_delta, seed=None), and several others.

These functions are made for single images (i.e. 3-D tensors with shape [height, width, color-channel]). How can I make them work on a batch of images (i.e. 4-D tensors with shape [batch, height, width, color-channel])?

A working example would be greatly appreciated!

questiondude

5 Answers

35

One possibility is to use the recently added tf.map_fn() to apply the single-image operator to each element of the batch.

result = tf.map_fn(lambda img: tf.image.random_flip_left_right(img), images)

This effectively builds the same computation as the loop keveman suggests, but it can be more efficient for larger batch sizes because it uses TensorFlow's support for loops.
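As a minimal sketch of how several single-image distortions could be combined and mapped over a batch, assuming TensorFlow 2.x: the helper name preprocess, the dummy batch shape and max_delta=0.2 are illustrative choices, not part of the answer above.

```python
import tensorflow as tf

def preprocess(image):
    """Distort one image of shape [height, width, color-channel] (hypothetical helper)."""
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)  # max_delta is an assumed value
    return image

# Dummy batch of 8 RGB images; in practice this would come from the input pipeline.
images = tf.random.uniform([8, 32, 32, 3], minval=0.0, maxval=1.0)

# map_fn unstacks the batch along axis 0, applies preprocess to each image, and restacks.
distorted = tf.map_fn(preprocess, images)
```

The result has the same shape and dtype as the input batch, so it can be fed to the rest of the model unchanged.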

Innat
mrry
  • 1
    Thanks it works! Is there a reason the functions in tf.image don't have this built-in? I have made all my image distortions in a single function preprocess() which is called from tf.map_fn(). I believe this causes the random distortions to be different for all the images because map_fn() calls preprocess() repeatedly with new random values. Please elaborate on the difference between your answer and the other one suggested, and why map_fn() is a better solution. I'm guessing tf.map_fn() loops over the images at run-time, so it doesn't add ops to the graph for each of the images in the batch? – questiondude Aug 15 '16 at 08:20
  • Thanks. I did this, and it worked. But training is now 5 times slower than without transformations, so it is not that efficient :-( – Rémi Sep 14 '17 at 16:02
4

You can call the image operation on each image in a loop and stack the results. For example:

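transformed_images = []
for i in range(batch_size):  # batch_size must be a Python int known when the graph is built
  transformed_images.append(tf.image.random_flip_left_right(images[i]))
result = tf.stack(transformed_images)  # back to [batch, height, width, color-channel]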
filiprafaj
keveman
  • Upvoted as thanks for the quick answer! I thought about something like this, but I believe this would add ops to the TensorFlow graph for each of the images in the batch and hence won't work for varying batch-sizes. I guess I could build a graph for each of the batch-sizes I need, but it seems rather messy. The other answer seems like the proper way to do it. Thanks again though. – questiondude Aug 15 '16 at 08:16
2

TL;DR: you can create a queue, define how to read and process a single element of the queue, and then make a batch, all with TF methods.

I'm not sure how it works internally, but if you use queues, read the images with TensorFlow methods and let TensorFlow create the batches, you can write the processing as if it were for a single image.

I haven't tested it on large datasets yet, so I don't know how well it performs (speed, memory consumption and so on). For now it may be better to create the batch yourself.

I saw this in the cifar10 example, which you can find here: https://github.com/tensorflow/tensorflow/tree/r0.10/tensorflow/models/image/cifar10

  1. First they create a queue with tf.train.string_input_producer. https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/image/cifar10/cifar10_input.py#L222 You can use a different type of queue; for example, I have tried tf.train.slice_input_producer for multiple images. You can read about it here: Tensorflow read images with labels
  2. Then they perform all the needed operations as for a single image. If only reading is needed, they just read; if processing is needed, they crop the image and apply other distortions. Reading is described in read_cifar10 and processing in distorted_inputs, which is here: https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/image/cifar10/cifar10_input.py#L138
  3. They pass the results of step 2 to tf.train.batch or tf.train.shuffle_batch, depending on the parameters, and return the batch from the inputs() and distorted_inputs() functions.
  4. The training code then reads the batch simply as images, labels = cifar10.distorted_inputs() and continues from there. It's here: https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/image/cifar10/cifar10_train.py#L66
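The same "process one element, then batch" pattern can be sketched with tf.data instead of queues. Note that this swaps in a different API from the one the cifar10 example uses, and assumes TensorFlow 2.x; the function name distort, the stand-in data and the crop size are all illustrative.

```python
import tensorflow as tf

def distort(image, label):
    """Per-example processing on a single [height, width, color-channel] image."""
    image = tf.image.random_crop(image, size=[24, 24, 3])  # assumed crop size
    image = tf.image.random_flip_left_right(image)
    return image, label

# Stand-in data: 10 random 32x32 RGB images with integer labels.
raw_images = tf.random.uniform([10, 32, 32, 3])
raw_labels = tf.range(10)

dataset = (
    tf.data.Dataset.from_tensor_slices((raw_images, raw_labels))
    .shuffle(buffer_size=10)  # shuffling, which the queue version gets from tf.train.shuffle_batch
    .map(distort)             # single-image ops, applied before batching
    .batch(4)                 # batches are formed from already-processed examples
)

images, labels = next(iter(dataset))
```

Because the distortions run before batch(), each image is processed independently and the model still receives ordinary 4-D batches.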
Community
ckorzhik
0

You could use tf.reverse to simulate tf.image.random_flip_left_right and tf.image.random_flip_up_down on 4-D tensors with shape [batch, height, width, channel].
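One way this might look, as a sketch assuming TensorFlow 2.x (where tf.where broadcasts its condition): the function name random_flip_batch and the 0.5 flip probability are assumptions made for illustration.

```python
import tensorflow as tf

def random_flip_batch(batch, axis):
    """Flip each image of a [batch, height, width, channel] tensor with probability 0.5.

    axis=2 reverses the width axis (left-right); axis=1 reverses the height axis (up-down).
    """
    batch_size = tf.shape(batch)[0]
    # One coin toss per image, shaped [batch, 1, 1, 1] so it broadcasts over H, W and C.
    coin = tf.random.uniform(tf.stack([batch_size, 1, 1, 1])) < 0.5
    flipped = tf.reverse(batch, axis=[axis])
    # Pick the flipped or the original pixels image by image.
    return tf.where(coin, flipped, batch)

batch = tf.random.uniform([6, 8, 8, 3])
out = random_flip_batch(batch, axis=2)
```

This computes the flipped version of every image and then selects per image, trading some extra computation for a fully vectorized op with no loop over the batch.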

0
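# maxval is exclusive, so random_number is drawn uniformly from {0, 1, 2, 3}.
random_number = tf.random_uniform([], minval=0, maxval=4, dtype=tf.int32)
# 2 < random_number holds only for 3, so the whole batch is flipped together with probability 1/4.
random_batch_flip = tf.where(tf.less(tf.constant(2), random_number), tf.image.flip_left_right(batch), batch)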

Reference: http://www.guidetomlandai.com/tutorials/tensorflow/if_statement/

Manuel Cuevas
  • Generally, answers are much more helpful if they include an explanation of what the code is intended to do, and why that solves the problem without introducing others. – Tim Diekmann May 26 '18 at 11:51