I'd like to train a classifier on an ImageNet dataset (1000 classes, each with around 1300 images). For some reason, I need each batch to contain 64 images from the same class, and consecutive batches to come from different classes. Is this possible (and efficient) with the latest TensorFlow?
tf.contrib.data.sample_from_datasets in TF 1.9 allows sampling from a list of tf.data.Dataset objects, with weights indicating the probabilities. I wonder if the following idea makes sense:
- Save data of each class as a separate tfrecord file.
- Pass a tf.data.Dataset.from_generator object as the weights. The object samples from a Categorical distribution such that each sample looks like [0,...,0,1,0,...,0], with 999 0s and a single 1;
- Create 1000 tf.data.Dataset objects, each linked to a tfrecord file.
I thought that, in this way, at each iteration sample_from_datasets would first sample a sparse weight vector that indicates which tf.data.Dataset to sample from, and then sample from that class.
Is it correct? Are there any other efficient ways?
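To make the intended behaviour concrete, here is a minimal pure-Python sketch of the sampling logic, independent of TensorFlow. It assumes the class is chosen once per batch rather than once per element, and that at least two classes exist. The names class_batches and toy_data are hypothetical and only serve the illustration.

```python
import itertools
import random


def class_batches(per_class_data, batch_size, num_batches, seed=0):
    """Yield (class_id, batch) pairs.

    Each batch holds items of a single class, and two consecutive
    batches never come from the same class (needs >= 2 classes).
    """
    rng = random.Random(seed)
    # One endless iterator per class, standing in for a repeated per-class dataset.
    streams = {c: itertools.cycle(items) for c, items in per_class_data.items()}
    classes = list(streams)
    previous = None
    for _ in range(num_batches):
        # Choose the class once per batch, excluding the class used last time.
        candidates = [c for c in classes if c != previous]
        current = rng.choice(candidates)
        # Take the next batch_size items from the chosen class only.
        batch = list(itertools.islice(streams[current], batch_size))
        yield current, batch
        previous = current


# Hypothetical toy data: 5 classes with 10 (label, index) items each.
toy_data = {c: [(c, i) for i in range(10)] for c in range(5)}
batches = list(class_batches(toy_data, batch_size=4, num_batches=20))
```

The point of the sketch is that the class decision has to hold for a whole batch. If the weights are re-sampled for every element, a batch assembled afterwards may mix classes, so batching each per-class dataset before sampling between them is one way to keep batches homogeneous.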
Update
As kindly suggested by P-Gn, one way to sample data from one class would be:
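    import tensorflow as tf  # TF 1.x, where tf.contrib is available

    # filenames, some_parser_fun, buffer_size, batch_size and
    # sample_same_class are placeholders defined elsewhere.
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(some_parser_fun)  # parse one datum from tfrecord
    dataset = dataset.shuffle(buffer_size)
    if sample_same_class:
        # Group examples by label so that each batch holds a single class;
        # key_func must return a scalar tf.int64 label.
        group_fun = tf.contrib.data.group_by_window(
            key_func=lambda data_x, data_y: data_y,
            reduce_func=lambda key, d: d.batch(batch_size),
            window_size=batch_size)
        dataset = dataset.apply(group_fun)
    else:
        dataset = dataset.batch(batch_size)
    dataset = dataset.repeat()
    data_batch = dataset.make_one_shot_iterator().get_next()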
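For intuition about what the grouping step does, here is a rough pure-Python sketch of windowed grouping by key. It is not the TensorFlow implementation: group_by_window_sketch and the toy stream are hypothetical, and emitting leftover partial windows at the end of the input is a choice made in this sketch.

```python
from collections import defaultdict


def group_by_window_sketch(examples, key_func, window_size):
    """Yield lists of examples that share a key, window_size at a time."""
    windows = defaultdict(list)
    for example in examples:
        key = key_func(example)
        windows[key].append(example)
        if len(windows[key]) == window_size:
            # This key's window is full: emit it as one batch and reset it.
            yield windows.pop(key)
    # Assumption of this sketch: leftover partial windows are emitted too.
    for leftover in windows.values():
        if leftover:
            yield leftover


# Hypothetical toy stream of (x, label) pairs with interleaved labels.
stream = [(i, i % 3) for i in range(10)]
batches = list(
    group_by_window_sketch(stream, key_func=lambda ex: ex[1], window_size=2))
```

In this sketch, batches come out in the order in which classes fill their windows, so every batch is single-class but nothing prevents two consecutive batches from sharing a class. The last batches may also be smaller than window_size, which is worth checking against the real transformation before relying on a fixed batch size.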
A follow-up question can be found at How to sample batch from a specific class?