
I just stumbled over this question: TensorFlow - Read all examples from a TFRecords at once?

The first answer suggests using tf.parse_example instead of parsing single examples, because that seems to be faster. But the code provided is not complete, and I don't know how to use it. If I batch and then use parse_example, I get a batch of features. Does that mean I need to unpack that batch in order to decode the JPEGs? The code from the answer is:

# The excerpt omitted its enclosing function; the name below is assumed.
def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'image/center': tf.VarLenFeature(tf.string),
    })
    image = features['image/center']
    image_decoded = tf.image.decode_jpeg(image.values[0], channels=3)
    return image_decoded

It suggests switching to:

batch = tf.train.batch([serialized_example], num_examples, capacity=num_examples)
parsed_examples = tf.parse_example(batch, feature_spec)

But how can I now decode those parsed_examples?

Salvador Dali
andre_bauer
  • What happens when you execute it as-is? According to the [tf.decode_raw documentation](https://www.tensorflow.org/versions/r0.11/api_docs/python/io_ops.html#decode_raw), it should decode the whole batch of examples. – sygi Nov 15 '16 at 12:50
  • Ah, sorry, I copied that code from the other question without realizing that it's a little different. I use JPEGs... I will edit the question! I think tf.image.decode_jpeg does not support batches. – andre_bauer Nov 16 '16 at 00:01

1 Answer


I had the same problem. I solved it with one of TensorFlow's higher-order operators, tf.map_fn:

batch = tf.train.batch([serialized_example], num_examples, capacity=num_examples)

parsed_examples = tf.parse_example(batch, 
    features={
        'image_jpg': tf.FixedLenFeature([], tf.string),
    })

raw_bytes_batch = parsed_examples['image_jpg']

def decode(raw_bytes):
    return tf.image.decode_jpeg(raw_bytes, channels=3)

image_decoded = tf.map_fn(decode, raw_bytes_batch, dtype=tf.uint8,
                          back_prop=False, parallel_iterations=10)
# image_decoded.set_shape([None, ..., ..., 3])

This should run the decode function on the JPEGs in parallel, up to 10 at a time as set by parallel_iterations.
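For anyone who wants to try the pattern end to end, here is a minimal sketch. It assumes TensorFlow 2.x, where the calls are named tf.io.parse_example and tf.map_fn takes fn_output_signature instead of dtype. The make_serialized_example helper and the tiny 4x4 black images are made up for illustration and stand in for real TFRecord contents.

```python
import tensorflow as tf


def make_serialized_example(height, width):
    # Hypothetical helper: builds a serialized tf.train.Example holding one
    # JPEG-encoded black image, standing in for a record read from a file.
    pixels = tf.zeros([height, width, 3], dtype=tf.uint8)
    jpeg_bytes = tf.io.encode_jpeg(pixels).numpy()
    example = tf.train.Example(features=tf.train.Features(feature={
        "image_jpg": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
    }))
    return example.SerializeToString()


# A batch of three serialized examples (a 1-D string tensor).
serialized_batch = tf.constant(
    [make_serialized_example(4, 4) for _ in range(3)])

# Parse the whole batch at once; the result is a 1-D tensor of JPEG bytes.
parsed = tf.io.parse_example(
    serialized_batch,
    {"image_jpg": tf.io.FixedLenFeature([], tf.string)})


def decode(raw_bytes):
    # decode_jpeg handles one image at a time, hence the map_fn below.
    return tf.image.decode_jpeg(raw_bytes, channels=3)


# Apply decode to every element of the batch and stack the results.
images = tf.map_fn(decode, parsed["image_jpg"],
                   fn_output_signature=tf.uint8)
```

Because map_fn stacks the per-element results into one tensor, all decoded images must have the same shape. If the JPEGs differ in size, resize or crop them inside decode before returning.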

sunside