5

I want to train a classifier on the ImageNet dataset (1000 classes), and I need each batch to contain 64 images from the same class, with consecutive batches drawn from different classes. So far, based on @shai's suggestion and this post, I have:

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os


class DS(Dataset):
    """Dataset wrapper that records, per class, the indices of its samples.

    Args:
        data: Indexable dataset whose items are ``(sample, class_label)`` pairs.
        num_classes: Number of classes. Labels are used as positions in a list
            of this length, so they must lie in ``range(num_classes)``.
    """

    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data

        # Fast path: when the wrapped dataset already exposes its labels as a
        # `targets` sequence, read them directly instead of loading (and
        # transforming) every sample just to look at its label.
        # NOTE(review): assumes `data.targets[i]` is the label of `data[i]`
        # (true for torchvision's DatasetFolder/ImageFolder) - confirm for
        # other dataset types. Skipped when a `target_transform` is set,
        # because then the labels returned by `data[i]` may differ.
        labels = getattr(data, 'targets', None)
        if labels is None or getattr(data, 'target_transform', None) is not None:
            labels = (class_label for _, class_label in data)

        # List of lists: sublist `c` holds the indices of the samples whose
        # label is `c`.
        self.indices = [[] for _ in range(num_classes)]
        for i, class_label in enumerate(labels):
            self.indices[int(class_label)].append(i)

    def classes(self):
        """Return the per-class index lists built in ``__init__``."""
        return self.indices

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset unchanged."""
        return self.data[index]

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.data)


class BatchSampler:
    """Yield batches of sample indices where each batch comes from one class.

    Args:
        classes: List of lists; sublist ``c`` contains the sample indices that
            belong to class ``c``.
        batch_size: Number of indices per batch. Must not exceed the size of
            the smallest class, because indices are drawn without repetition.

    Raises:
        ValueError: If ``batch_size`` is not positive or is larger than the
            smallest class.
    """

    def __init__(self, classes, batch_size):
        if batch_size < 1:
            raise ValueError('batch_size should be a positive integer')

        class_sizes = [len(x) for x in classes]
        self.classes = classes
        self.batch_size = batch_size
        self.n_batches = sum(class_sizes) // batch_size
        self.min_class_size = min(class_sizes)
        # Visiting order for the first len(classes) batches; shuffled once here.
        self.class_range = list(range(len(self.classes)))
        random.shuffle(self.class_range)

        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if batch_size > self.min_class_size:
            raise ValueError(
                'batch_size should be at most {}'.format(self.min_class_size))

    def __len__(self):
        """Return the number of batches produced by one pass of ``__iter__``."""
        return self.n_batches

    def __iter__(self):
        """Lazily yield ``n_batches`` arrays of ``batch_size`` indices each.

        The first ``len(classes)`` batches visit every class once in the
        shuffled order; later batches pick a random class that differs from
        the class of the previous batch whenever more than one class exists.
        """
        previous = None
        for j in range(self.n_batches):
            if j < len(self.class_range):
                batch_class = self.class_range[j]
            else:
                # Exclude the previous class so consecutive batches differ;
                # fall back to the full list when there is only one class.
                candidates = [c for c in self.class_range if c != previous]
                batch_class = random.choice(candidates or self.class_range)
            previous = batch_class
            # replace=False: no index appears twice inside a batch.
            yield np.random.choice(self.classes[batch_class], self.batch_size,
                                   replace=False)


def main():
    """Iterate once over the class-homogeneous loader and report label coverage.

    Reads the module-level ``train_dataset`` and ``args`` set up by the script
    entry point, prints the single label of every batch, and finally prints
    how many distinct labels were seen.
    """
    # The dummy dataset carries `num_classes`; a dataset without that
    # attribute is expected to list its class names in `classes` instead.
    num_classes = getattr(train_dataset, 'num_classes', None)
    if num_classes is None:
        num_classes = len(train_dataset.classes)

    _train_dataset = DS(train_dataset, num_classes)
    _batch_sampler = BatchSampler(_train_dataset.classes(), batch_size=args.batch_size)
    _train_loader = DataLoader(dataset=_train_dataset, batch_sampler=_batch_sampler)

    labels = set()
    for inputs, _labels in _train_loader:
        # `.item()` only works on a one-element tensor, so this line also
        # checks that the batch really contains a single class.
        label = torch.unique(_labels).item()
        labels.add(label)
        print("Unique labels: {}".format(label))

    print('Length of traversed unique labels: {}'.format(len(labels)))


if __name__ == "__main__":
    # Script entry point: parse the CLI, build the train/val datasets
    # (random fake data with --dummy, ImageFolder otherwise), then run main().

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    # Help text now states the real default (64); it previously said 256.
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    args = parser.parse_args()

    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Training images get random crop/flip augmentation.
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        # Validation images use a fixed resize + center crop.
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    # Samplers are initialized to None and train_sampler will be replaced
    # NOTE: main() builds its own DataLoader around the batch sampler; the two
    # loaders below are the plain (non class-grouped) ones and are not used by it.
    train_sampler, val_sampler = None, None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    main()

which prints: Length of traversed unique labels: 100.

However, creating self.indices in the for loop takes a lot of time. Is there a more efficient way to construct this sampler?

EDIT: yield implementation

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os
from tqdm import tqdm
import os.path


class DS(Dataset):
    """Dataset wrapper that records, per class, the indices of its samples.

    Args:
        data: Sized, indexable dataset whose items are ``(sample, class_label)``
            pairs.
        num_classes: Number of classes. Labels are used as positions in a list
            of this length, so they must lie in ``range(num_classes)``.
    """

    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data
        self.data_len = len(data)

        indices = [[] for _ in range(num_classes)]

        labels = getattr(data, 'targets', None)
        if labels is not None and getattr(data, 'target_transform', None) is None:
            # Fast path: the wrapped dataset already exposes its labels, so no
            # sample has to be loaded or transformed.
            # NOTE(review): assumes `data.targets[i]` is the label of `data[i]`
            # (true for torchvision's DatasetFolder/ImageFolder) - confirm for
            # other dataset types.
            for i, class_label in enumerate(labels):
                indices[int(class_label)].append(i)
        else:
            # Slow path: fetch every item just to read its label.
            for i, (_, class_label) in tqdm(enumerate(data), total=len(data), miniters=1,
                                            desc='Building class indices dataset..'):
                indices[class_label].append(i)

        self.indices = indices

    def per_class_sample_indices(self):
        """Return the list of lists: sublist ``c`` holds the indices of class ``c``."""
        return self.indices

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset unchanged."""
        return self.data[index]

    def __len__(self):
        """Return the length of the wrapped dataset, captured at construction."""
        return self.data_len


class BatchSampler:
    """Lazily yield batches of sample indices, each drawn from a single class.

    Args:
        per_class_sample_indices: List of lists; sublist ``c`` contains the
            sample indices that belong to class ``c``.
        batch_size: Target number of indices per batch. A class with fewer
            samples than this is yielded whole as a shorter batch.

    Attributes are kept as plain values: ``n_batches`` is the integer number
    of batches per pass. The former ``n_batches()`` method was shadowed by
    that attribute and could never be called; use ``len(sampler)`` instead.
    """

    def __init__(self, per_class_sample_indices, batch_size):
        class_sizes = [len(x) for x in per_class_sample_indices]
        self.per_class_sample_indices = per_class_sample_indices
        self.n_batches = sum(class_sizes) // batch_size
        self.min_class_size = min(class_sizes)
        self.batch_size = batch_size
        # Visiting order for the first len(classes) batches; shuffled once here.
        self.class_range = list(range(len(self.per_class_sample_indices)))
        random.shuffle(self.class_range)

    def __iter__(self):
        """Yield ``n_batches`` batches, one at a time.

        The first ``len(class_range)`` batches visit every class once in the
        shuffled order; later batches pick a random class that differs from
        the class of the previous batch whenever more than one class exists.
        """
        previous = None
        for j in range(self.n_batches):
            if j < len(self.class_range):
                batch_class = self.class_range[j]
            else:
                # Exclude the previous class so consecutive batches differ;
                # fall back to the full list when there is only one class.
                candidates = [c for c in self.class_range if c != previous]
                batch_class = random.choice(candidates or self.class_range)
            previous = batch_class

            members = self.per_class_sample_indices[batch_class]
            if self.batch_size <= len(members):
                # replace=False: no index appears twice inside a batch.
                batch = np.random.choice(members, self.batch_size, replace=False)
            else:
                # Class smaller than a batch: use all of its samples.
                batch = members
            yield batch

    def __len__(self):
        """Return the number of batches produced by one pass of ``__iter__``."""
        return self.n_batches


def main():
    """Build (or load cached) per-class indices, then iterate the loader once.

    Reads the module-level ``train_dataset``, ``num_classes`` and ``args`` set
    up by the script entry point. Prints the single label of every batch and
    finally how many distinct labels were seen.
    """
    file_path = 'a_file_path'
    file_name = 'per_class_sample_indices.pt'
    cache_file = os.path.join(file_path, file_name)

    # NOTE(review): the cache is keyed only by this fixed path, so switching
    # datasets (e.g. --dummy vs. real data) would reuse stale indices.
    if not os.path.exists(cache_file):
        print('File: {} does not exists. Create it.'.format(file_name))
        per_class_sample_indices = DS(train_dataset, num_classes).per_class_sample_indices()
        # Create the cache directory first so the save below cannot fail on a
        # missing parent directory.
        os.makedirs(file_path, exist_ok=True)
        torch.save(per_class_sample_indices, cache_file)
    else:
        # NOTE(review): torch.load unpickles the file; only load cache files
        # this script wrote itself.
        per_class_sample_indices = torch.load(cache_file)
        print('File: {} exists. Do not create it.'.format(file_name))

    batch_sampler = BatchSampler(per_class_sample_indices,
                                 batch_size=args.batch_size)

    # batch_size / shuffle / sampler are intentionally not passed: the batch
    # sampler alone decides which indices form each batch.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        num_workers=args.workers,
        pin_memory=True,
        batch_sampler=batch_sampler
    )

    # No validation loader is built here; validation does not use a sampler.

    labels = set()
    for inputs, _labels in train_loader:
        # `.item()` only works on a one-element tensor, so this line also
        # checks that the batch really contains a single class.
        label = torch.unique(_labels).item()
        labels.add(label)
        print("Unique labels: {}".format(label))

    print('Length of traversed unique labels: {}'.format(len(labels)))


if __name__ == "__main__":
    # Script entry point: parse the CLI, build the train/val datasets
    # (random fake data with --dummy, ImageFolder otherwise), set
    # `num_classes`, then run main().

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    # Help text now states the real default (64); it previously said 256.
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    args = parser.parse_args()

    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Training images get random crop/flip augmentation.
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        # Validation images use a fixed resize + center crop.
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        # One class per entry of the dataset's class-name list.
        num_classes = len(train_dataset.classes)

    main()

A similar post but in TensorFlow can be found here

Thoth
  • 993
  • 12
  • 36

2 Answers2

3

Your code seems fine. The issue here is not the sampler but the preprocessing step you are required to perform in order to sort out the instance indices by their class. Since this is always the same sort, I recommend you cache this information (the data contained inside of self.indices) on your file system such that you avoid having to reconstruct it on every dataset load. You can do so using either numpy.save or torch.save.

Ivan
  • 34,531
  • 8
  • 55
  • 100
  • Ivan, thanks for your interest. I will store the indices as you said. In my last execution it took some hours to create the indices. However, after creating `indices` I got `RuntimeError: CUDA out of memory.`, which I was not getting without the `batch_sampler`. Is it possible for `indices` to take up that much space? Can we overcome it? Is there a more efficient way? The code I am using is [this](https://github.com/pytorch/examples/tree/main/imagenet). – Thoth Nov 05 '22 at 11:36
  • Shall we avoid creating the list in the sampler and somehow return one batch at a time on the fly? – Thoth Nov 05 '22 at 12:27
  • The memory error has nothing to do with the sampler or the dataset initialization step, since that step happens on the CPU; this would require further investigation. Are you using ImageNet-21K or ImageNet-1K? – Ivan Nov 05 '22 at 13:00
  • I am using the version `ImageNet-ILSVRC2012` downloaded from [here](https://www.image-net.org/download.php). I do not know the references ImageNet-21K and ImageNet-1K. The GPU I am using has 40GB memory. – Thoth Nov 05 '22 at 13:17
  • Please ignore my earlier messages about the GPU memory. I see that the GPU is also being used by others, which I hadn't noticed. I will run a few tests over the next few days and come back. In the meantime, is it more memory (RAM) efficient to implement the sampler with yield? Could you please take a look at my post? – Thoth Nov 05 '22 at 13:48
  • 1
    Yes it will be more efficient since the yielded value will be computed on request (on the fly) without needing to precompute the whole sequence of sampled instances. – Ivan Nov 07 '22 at 16:01
2

You should write your own batch_sampler class for the DataLoader.

Shai
  • 111,146
  • 38
  • 238
  • 371
  • I am reading [this](https://stackoverflow.com/questions/66065272/customizing-the-batch-with-specific-elements) post that talks about batch_sampler. Can you provide a similar example for my exact problem that I can study? Sorry, I am new to PyTorch and I am looking for a fail-safe approach to my problem. – Thoth Nov 02 '22 at 10:57
  • I think the post [here](https://stackoverflow.com/questions/66065272/customizing-the-batch-with-specific-elements) provides the necessary information to build most of the functionality. The only thing that bothers me is that, in the end, not all classes are visited, which means some samples are never visited either. Could you please take a look at my original post and suggest how to make it work correctly? – Thoth Nov 02 '22 at 15:14
  • @Thoth maybe after [CVPR deadline](https://cvpr2023.thecvf.com/) – Shai Nov 02 '22 at 15:26
  • 1
    Good luck with the deadline and especially with the future results! – Thoth Nov 02 '22 at 15:35