
Context

I have switched to the Dataset API (based on this), which has resulted in a very significant performance boost compared to using queues.

I am using TensorFlow 1.6.

Problem

I have implemented resampling based on the very helpful explanation here.
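
For reference, the arguments passed to rejection_resample look roughly like this (the feature key 'label', the number of classes and the distributions below are placeholders, not my actual values). Since resampling happens before _parse_function in my pipeline, class_func receives serialized records and has to extract the label itself:

import tensorflow as tf

# Placeholder distributions for a binary problem (assumed values).
target_dist = [0.5, 0.5]    # desired class balance
initial_dist = [0.9, 0.1]   # approximate class balance in the data

def class_func(serialized_example):
    # The element is still a serialized tf.train.Example at this point,
    # so parse just enough of it to recover the class label.
    features = tf.parse_single_example(
        serialized_example,
        features={'label': tf.FixedLenFeature([], tf.int64)})
    return tf.cast(features['label'], tf.int32)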

The problem is that no matter where I place the resampling stage in the input pipeline, the program raises a ResourceExhaustedError. Changing the batch_size does not fix it; the error only goes away when I use a fraction of all input files.

My training files (.tfrecords) total ~200 GB and are split over a few hundred shards, but the Dataset API has handled them very well so far; it is only the resampling that causes this problem.

Input pipeline example

batch_size = 20000

# List all shards and interleave records from them in parallel.
dataset = tf.data.Dataset.list_files(file_list)
dataset = dataset.apply(tf.contrib.data.parallel_interleave(
    tf.data.TFRecordDataset, cycle_length=len(file_list), sloppy=True, block_length=10))

if resample:
    # Rebalance classes; rejection_resample yields (class, data) tuples,
    # so drop the class label afterwards.
    dataset = dataset.apply(tf.contrib.data.rejection_resample(
        class_func=class_func, target_dist=target_dist, initial_dist=initial_dist, seed=5))
    dataset = dataset.map(lambda _, data: data)

dataset = dataset.shuffle(5 * batch_size, seed=5)
dataset = dataset.apply(tf.contrib.data.map_and_batch(
    map_func=_parse_function, batch_size=batch_size, num_parallel_batches=8))
dataset = dataset.prefetch(10)

return dataset
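
In case it matters, this is roughly how the dataset is consumed (input_fn is just a placeholder name for the function containing the pipeline above; the actual training loop is omitted):

import tensorflow as tf

dataset = input_fn()                           # builds the pipeline shown above
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    batch = sess.run(next_batch)               # one batch of parsed (and resampled) examples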

If anyone has an idea of how to work around this, it would be much appreciated!

