Here is a potential solution.
For the sake of the demonstration, I am using a Python generator instead of TFRecords as input (I am assuming you know how to use tf.data to read and parse the files in each folder; other threads cover this, e.g. here).
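In case you want to start directly from TFRecord files, a per-class dataset could be built roughly as follows. Note that the folder layout, feature keys, shapes, and dtypes below are assumptions you would need to adapt to your own records:

import glob
import tensorflow as tf

def parse_example(serialized_example):
    """ Hypothetical parsing function - adapt the feature spec to your records. """
    features = tf.parse_single_example(serialized_example, {
        "image": tf.FixedLenFeature([32 * 32], tf.int64),
        "label": tf.FixedLenFeature([], tf.int64),
    })
    image = tf.reshape(tf.cast(features["image"], tf.int32), [32, 32])
    label = tf.cast(features["label"], tf.int32)
    return image, label

def get_class_dataset(class_dir):
    """ Returns a dataset for one class, assuming one folder of TFRecords per class. """
    filenames = glob.glob(class_dir + "/*.tfrecord")
    return tf.data.TFRecordDataset(filenames).map(parse_example)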
import tensorflow as tf
import numpy as np


def get_class_generator(class_id, num_el, el_shape=(32, 32), el_dtype=np.int32):
    """ Returns a dummy generator,
    outputting "num_el" elements of a single class (input data & class label)
    """
    def class_generator():
        for x in range(num_el):
            element = np.ones(el_shape, dtype=el_dtype) * x
            yield element, class_id
    return class_generator
def concatenate_datasets(datasets):
    """ Concatenate a list of tensors into a single dataset
    (each tensor becoming one element of the resulting dataset).
    Snippet by user2781994 (https://stackoverflow.com/a/49069420/624547)
    """
    ds0 = tf.data.Dataset.from_tensors(datasets[0])
    for ds1 in datasets[1:]:
        ds0 = ds0.concatenate(tf.data.Dataset.from_tensors(ds1))
    return ds0
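To illustrate what this helper does, here is a tiny standalone check (using the same TF1 session syntax as the rest of the answer; not needed for the actual pipeline):

toy_next = (concatenate_datasets((tf.constant(0), tf.constant(1), tf.constant(2)))
            .make_one_shot_iterator().get_next())
with tf.Session() as sess:
    print([sess.run(toy_next) for _ in range(3)])  # prints [0, 1, 2]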
num_classes = 11
class_batch_size = 3
num_classes_per_batch = 5
# note: using 3 instead of 5 for class_batch_size in this example,
# just to distinguish between the two variables.

# Initializing the per-class datasets:
# (note: replace tf.data.Dataset.from_generator(...) to suit your use case,
#  e.g. tf.data.TFRecordDataset(glob.glob(perclass_tfrecords_path))
#       .map(your_parsing_function))
class_datasets = [tf.data.Dataset
                  .from_generator(get_class_generator(
                      class_id, num_el=np.random.randint(1, 60)
                      # ^ simulating an unequal number of samples per class
                  ), (tf.int32, tf.int32), ([32, 32], []))
                  .repeat(-1)
                  .batch(class_batch_size)
                  for class_id in range(num_classes)]
# Initializing the complete dataset:
dataset = (tf.data.Dataset
           # Zipping the class datasets together, then concatenating each
           # zipped tuple back into a flat dataset of per-class batches:
           .zip(tuple(class_datasets))
           .flat_map(lambda *args: concatenate_datasets(args))
           # Shuffling the per-class batches:
           .shuffle(buffer_size=num_classes)
           # Slicing each class batch of shape (class_batch_size, ...)
           # back into single elements:
           .flat_map(lambda *args: tf.data.Dataset.from_tensor_slices(args))
           # Re-batching the elements, so that each batch contains
           # num_classes_per_batch * class_batch_size of them:
           .batch(num_classes_per_batch * class_batch_size))
# Visualizing the results:
next_batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for i in range(10):
        batch = sess.run(next_batch)
        print(">> batch {}".format(i))
        print("- inputs shape: {} ; label shape: {}".format(batch[0].shape, batch[1].shape))
        print("- class values: {}".format(batch[1]))
Outputs:
>> batch 0
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [ 1 1 1 0 0 0 10 10 10 2 2 2 9 9 9]
>> batch 1
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [0 0 0 2 2 2 3 3 3 5 5 5 6 6 6]
>> batch 2
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [ 9 9 9 8 8 8 4 4 4 3 3 3 10 10 10]
>> batch 3
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [7 7 7 8 8 8 6 6 6 6 6 6 2 2 2]
>> batch 4
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [1 1 1 0 0 0 1 1 1 8 8 8 5 5 5]
>> batch 5
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [2 2 2 4 4 4 9 9 9 5 5 5 5 5 5]
>> batch 6
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [0 0 0 7 7 7 3 3 3 9 9 9 7 7 7]
>> batch 7
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [10 10 10 10 10 10 1 1 1 6 6 6 7 7 7]
>> batch 8
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [4 4 4 3 3 3 5 5 5 6 6 6 3 3 3]
>> batch 9
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [8 8 8 9 9 9 2 2 2 8 8 8 0 0 0]
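Side note: the snippet above uses the TF1 session/iterator API. If you are on TF 2.x, the dataset itself can be built the same way (passing output_types/output_shapes positionally to from_generator is deprecated there but still works), and you can iterate over it eagerly. A minimal sketch replacing the visualization block:

for i, (inputs, labels) in enumerate(dataset.take(10)):
    print(">> batch {}".format(i))
    print("- inputs shape: {} ; label shape: {}".format(inputs.shape, labels.shape))
    print("- class values: {}".format(labels.numpy()))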