
I'm trying to build some debugging code into my TensorFlow dataset pipeline. Basically, if TFRecord parsing fails on a certain file, I'd like to be able to figure out which file that is. My dream would be to run a number of asserts in my parsing function that report the filename if they fail.

My pipeline looks something like this:

dataset = (tf.data.Dataset.from_tensor_slices(file_list)
           .apply(tf.contrib.data.parallel_interleave(
               lambda f: tf.data.TFRecordDataset(f), cycle_length=4))
           .map(parse_func, num_parallel_calls=params.num_cores)
           .map(_func_for_other_stuff))

Ideally I'd pass the filename through in the parallel_interleave step, but if I have the anonymous function return a (filename, TFRecordDataset) tuple, I get:

TypeError: `map_func` must return a `Dataset` object.

I've also tried to include the filename in the file itself, as in this question, but am having issues there because filenames are of variable length.


1 Answer

The return value of the function passed to tf.contrib.data.parallel_interleave() must be a tf.data.Dataset. Therefore you can solve this by attaching the filename tensor to each element of the TFRecordDataset, using tf.data.Dataset.zip() as follows:

def read_records_func(filename):
  records = tf.data.TFRecordDataset(filename)

  # Create a dataset from the filename tensor and repeat it indefinitely.
  filename_as_dataset = tf.data.Dataset.from_tensors(filename).repeat(None)

  # Pair each record with the filename it came from.
  return tf.data.Dataset.zip((filename_as_dataset, records))
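
Note that repeat(None) is equivalent to repeat() with no arguments: the one-element filename dataset repeats forever, and zip() stops as soon as the shorter TFRecordDataset is exhausted, so every record ends up paired with its source filename.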

dataset = (tf.data.Dataset.from_tensor_slices(file_list)
           .apply(tf.contrib.data.parallel_interleave(read_records_func, cycle_length=4))
           .map(parse_func, num_parallel_calls=params.num_cores)
           .map(_func_for_other_stuff))
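
Each element of the resulting dataset is now a (filename, serialized_record) pair, so parse_func can take both and use the filename in its error checks. Here is a minimal sketch of such a parse_func; the feature spec (a single fixed-length float feature called "features") is hypothetical, so substitute your own schema:

import tensorflow as tf

def parse_func(filename, serialized_record):
  # Hypothetical feature spec; replace with your actual schema.
  parsed = tf.parse_single_example(
      serialized_record,
      features={"features": tf.FixedLenFeature([128], tf.float32)})
  values = parsed["features"]

  # Fail loudly, naming the offending file, if the record looks corrupt.
  # tf.Assert prints its data tensors (including the filename) when the
  # condition is false.
  assert_op = tf.Assert(
      tf.reduce_all(tf.is_finite(values)),
      ["Non-finite value in file:", filename])
  with tf.control_dependencies([assert_op]):
    values = tf.identity(values)

  return filename, values

If you only need the filename for debugging, a later .map(lambda f, v: v) can drop it again before _func_for_other_stuff runs.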