If you want to shuffle the whole dataset in one go, you can use TensorFlow's built-in full shuffle:
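A minimal sketch (assuming `dataset` is an existing tf.data.Dataset whose cardinality is known; the seed value is just a placeholder):

# Full shuffle: the buffer holds every element of the dataset.
dataset = dataset.shuffle(dataset.cardinality(), seed=42)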
Note: shuffle(dataset.cardinality()) loads the full dataset into
memory so that it can be shuffled. This will cause a memory overflow
(OOM) error if the dataset is too large, so full-shuffle should only
be used for datasets that are known to fit in memory, such as
datasets of filenames or other small datasets.
As the note says, this will raise an out-of-memory (OOM) error if the dataset does not fit in your available memory.
So I wrote the following method to keep memory usage under control.
Please note:
- I use it for inspecting or exploring my test dataset.
- I do not recommend it for training or validation; use TensorFlow's native methods in that case.
import tensorflow as tf
from tqdm import tqdm


def tf_shuffle_dataset(dataset, batch_size, seed=None):
    """
    Shuffles a TensorFlow dataset in a memory-friendly, batch-based way:
    elements are shuffled within fixed-size batches, and the order of the
    batches themselves is shuffled as well.

    Args:
        dataset (tf.data.Dataset): The input dataset to shuffle.
        batch_size (int): Size of each batch.
        seed (int, optional): Seed for shuffle reproducibility.

    Returns:
        tf.data.Dataset: Shuffled dataset.

    Example:
        Consider the dataset [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] with batch_size = 2.

        1. The dataset is divided into the following batches:
           [1, 2], [3, 4], [5, 6], [7, 8], [9, 10]
        2. Each batch is shuffled. Assume the shuffled batches are:
           [2, 1], [4, 3], [6, 5], [8, 7], [10, 9] (the actual shuffle may differ)
        3. The order of these shuffled batches is then shuffled. Assume the new order is:
           [4, 3], [2, 1], [8, 7], [10, 9], [6, 5] (the actual shuffle may differ)
        4. These batches are concatenated together to give the final shuffled dataset:
           [4, 3, 2, 1, 8, 7, 10, 9, 6, 5]
    """
    if not isinstance(dataset, tf.data.Dataset):
        raise ValueError("The provided dataset is not an instance of tf.data.Dataset.")

    # Split the dataset into batches.
    # Note: a trailing partial batch (num_elements % batch_size elements) is dropped.
    num_elements = sum(1 for _ in dataset)
    num_batches = num_elements // batch_size
    batches = [dataset.skip(i * batch_size).take(batch_size) for i in range(num_batches)]

    # Shuffle each batch individually; the buffer only needs to hold one batch.
    shuffled_batches = [batch.shuffle(buffer_size=batch_size, seed=seed) for batch in batches]

    # Shuffle the order of the batches themselves.
    batch_order = tf.random.shuffle(tf.range(num_batches), seed=seed)

    # Concatenate the shuffled batches in the shuffled order to build the final dataset.
    shuffled_dataset = shuffled_batches[batch_order[0].numpy()]
    for i in tqdm(batch_order[1:], desc="Shuffling dataset", unit="batch"):
        shuffled_dataset = shuffled_dataset.concatenate(shuffled_batches[i.numpy()])

    return shuffled_dataset
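For completeness, a minimal usage sketch (the example dataset and the batch_size/seed values are placeholders, not part of the method itself):

# Build a small in-memory example dataset.
example_ds = tf.data.Dataset.from_tensor_slices(list(range(1, 11)))

# Shuffle it with the helper above.
shuffled_ds = tf_shuffle_dataset(example_ds, batch_size=2, seed=42)

# Inspect the result, e.g. when exploring a test set.
print(list(shuffled_ds.as_numpy_iterator()))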