There is a very similar answer by @mrry here.
Sampling balanced batches
In face recognition we often use triplet loss (or similar losses) to train the model. The usual way to sample triplets to compute the loss is to create a balanced batch of images where we have for instance 10 different classes (i.e. 10 different people) with 5 images each. This gives a total batch size of 50 in this example.
More generally, the problem is to sample num_classes_per_batch (10 in the example) classes, and then sample num_images_per_class (5 in the example) images for each class. The total batch size is:
batch_size = num_classes_per_batch * num_images_per_class
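To make the sampling scheme concrete, here is a minimal pure-Python sketch of how the labels of one balanced batch could be drawn. It is not TensorFlow code; the helper name `sample_batch_labels` and the value `num_classes=100` are made up for illustration.

```python
import random

def sample_batch_labels(num_classes, num_classes_per_batch, num_images_per_class):
    """Return the class label of every element in one balanced batch."""
    # Pick distinct classes for this batch (no class is drawn twice).
    sampled = random.sample(range(num_classes), num_classes_per_batch)
    # Repeat each sampled class once per image drawn from it.
    return [label for label in sampled for _ in range(num_images_per_class)]

# Hypothetical sizes: 100 classes in total, 10 classes x 5 images per batch.
labels = sample_batch_labels(num_classes=100, num_classes_per_batch=10, num_images_per_class=5)
print(len(labels))  # 50
```

The length of the result is always num_classes_per_batch * num_images_per_class, which is the batch size defined above.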
Have one dataset for each class
The easiest way to deal with a lot of different classes (100,000 in MS-Celeb) is to create one dataset for each class.
For instance you can have one tfrecord for each class and create the datasets like this:
import tensorflow as tf

# Build one dataset per class; repeat(None) repeats each one indefinitely.
filenames = ["class_0.tfrecords", "class_1.tfrecords"]  # ... one file per class
datasets = [tf.data.TFRecordDataset(f).repeat(None) for f in filenames]
Sample from the datasets
Now we would like to be able to sample from these datasets. For instance we want the following labels in our batch:
1 1 1 3 3 3 9 9 9 4 4 4
This corresponds to num_classes_per_batch=4 and num_images_per_class=3.
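As a sanity check, a label sequence like the one above can be tested for balance in plain Python. This is only a sketch; the helper `is_balanced` is a hypothetical name and not part of TensorFlow.

```python
from collections import Counter

def is_balanced(labels, num_classes_per_batch, num_images_per_class):
    """Check that a batch has the expected classes-by-images structure."""
    counts = Counter(labels)
    # Exactly `num_classes_per_batch` distinct classes, each appearing
    # exactly `num_images_per_class` times.
    return (len(counts) == num_classes_per_batch
            and all(c == num_images_per_class for c in counts.values()))

batch_labels = [1, 1, 1, 3, 3, 3, 9, 9, 9, 4, 4, 4]
print(is_balanced(batch_labels, 4, 3))  # True
```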
To do this we will need to use features that will be released in r1.9. The function should be called tf.contrib.data.choose_from_datasets (see here for a discussion on this).
It should look like:
def choose_from_datasets(datasets, selector):
    """Chooses elements with indices from selector among the datasets in `datasets`."""
So we create this selector, which will output 1 1 1 3 3 3 9 9 9 4 4 4, and combine it with datasets to obtain our final dataset that will output balanced batches:
def generator(_):
    # Sample `num_classes_per_batch` classes for the batch.
    # `num_classes` is the total number of classes, i.e. len(datasets).
    sampled = tf.random_shuffle(tf.range(num_classes))[:num_classes_per_batch]
    # Repeat each element `num_images_per_class` times
    batch_labels = tf.tile(tf.expand_dims(sampled, -1), [1, num_images_per_class])
    return tf.to_int64(tf.reshape(batch_labels, [-1]))

# Each Counter element yields one batch worth of labels; unbatch flattens them.
selector = tf.contrib.data.Counter().map(generator)
selector = selector.apply(tf.contrib.data.unbatch())

dataset = tf.contrib.data.choose_from_datasets(datasets, selector)

# Batch
batch_size = num_classes_per_batch * num_images_per_class
dataset = dataset.batch(batch_size)
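To see what choosing from the datasets does to the stream of elements, here is a small pure-Python model of the selection step. It only mimics the behaviour described in the docstring above and is not the TensorFlow implementation: `choose_from_iterators` is a hypothetical name, and `itertools.cycle` stands in for a dataset with `.repeat(None)`.

```python
import itertools

def choose_from_iterators(iterators, selector):
    """Yield the next element of iterators[i] for each index i in selector."""
    for index in selector:
        yield next(iterators[index])

# Stand-ins for two per-class datasets that repeat forever.
class_a = itertools.cycle(["a0", "a1"])
class_b = itertools.cycle(["b0", "b1", "b2"])

# The selector decides, element by element, which class to read from.
picked = list(choose_from_iterators([class_a, class_b], [0, 0, 1, 0, 1, 1]))
print(picked)  # ['a0', 'a1', 'b0', 'a0', 'b1', 'b2']
```

Each source keeps its own position, so picking from class 0 a third time wraps around to "a0". This is why the per-class datasets need to repeat: a small class would otherwise run out of elements.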
You can test this with the nightly TensorFlow build and by using DirectedInterleaveDataset as a workaround:
# The working option right now is
from tensorflow.contrib.data.python.ops.interleave_ops import DirectedInterleaveDataset
dataset = DirectedInterleaveDataset(selector, datasets)
I also wrote about this workaround here.