I'd like to train a classifier on an ImageNet dataset (1000 classes, each with around 1300 images). For some reason, I need each batch to contain 64 images from one specific class (provided as an int or a placeholder). How can I do this efficiently with the latest TensorFlow?
This is a follow-up question to "How to sample batch from only one class at each iteration".
My current thought is to use tf.data.Dataset.filter:
specific_class = 2 # as an example
dataset = tf.data.TFRecordDataset(filenames)
# __parser_fun__ produces datum tuple (x, y)
dataset = dataset.map(__parser_fun__, num_parallel_calls=num_threads)
dataset = dataset.shuffle(20000)
# print(dataset) gives <ShuffleDataset shapes: ((3, 128, 128), (1,)),
# types: (tf.float32, tf.int64)>
dataset = dataset.filter(lambda x, y: tf.equal(y[0], specific_class))
dataset = dataset.batch(64)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x_batch, y_batch = iterator.get_next()
A minor problem with filter is that I need to construct a new iterator every time I want to sample from a different class.
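To make the iterator concern concrete, here is a minimal sketch of one possible workaround, not a confirmed solution: the class id is fed through a placeholder and a single initializable iterator is re-initialized per class instead of being rebuilt. It assumes graph-mode TensorFlow accessed through `tf.compat.v1`. The toy in-memory data, the helper name `sample_labels_per_class`, and the small batch size are all made up for illustration and stand in for the parsed TFRecord pipeline.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API, assumed available (TF 1.13+ or 2.x)


def sample_labels_per_class(classes, batch_size=5):
    """Hypothetical helper: draw one batch of labels for each requested class."""
    # Toy stand-in for the parsed data: 40 examples spread over 4 classes.
    labels = (np.arange(40, dtype=np.int64) % 4)[:, None]   # shape (40, 1)
    images = np.random.rand(40, 3, 8, 8).astype(np.float32)

    results = {}
    with tf.Graph().as_default():
        # The class to sample from is a scalar placeholder, not a Python int.
        class_ph = tf1.placeholder(tf.int64, shape=[])

        dataset = tf.data.Dataset.from_tensor_slices((images, labels))
        # Filter first, so the shuffle buffer only holds the selected class.
        dataset = dataset.filter(lambda x, y: tf.equal(y[0], class_ph))
        dataset = dataset.shuffle(100).repeat().batch(batch_size)

        # Built once; re-initialized (not reconstructed) for each class.
        iterator = tf1.data.make_initializable_iterator(dataset)
        x_batch, y_batch = iterator.get_next()

        with tf1.Session() as sess:
            for cls in classes:
                sess.run(iterator.initializer, feed_dict={class_ph: cls})
                _, y = sess.run([x_batch, y_batch])
                results[cls] = y
    return results
```

Calling `sample_labels_per_class([0, 3])` should return one label batch per class. Note that re-initializing still restarts the scan over the underlying data, so this only avoids rebuilding the graph; it does not avoid the cost of filtering.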
Another idea is to use tf.contrib.data.rejection_resample, but it seems computationally prohibitive (or is it?).
Is there another efficient way to sample batches from a particular class?