
Assume I am using the following calls:

import torch
import torchvision
trainset = torchvision.datasets.ImageFolder(root="imgs/", transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1)

As far as I can tell, this defines trainset as all the images in the folder "imgs/", with each image's label determined by the sub-folder it is in.
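To make that labeling concrete, here is a rough pure-Python sketch of the mapping ImageFolder builds: the sorted sub-folder names become class indices 0, 1, ..., and every image file under a sub-folder gets that folder's index. The function name `folder_labels` and the short `IMG_EXTS` tuple are illustrative assumptions (ImageFolder accepts more extensions than these); the real ImageFolder exposes its result through its `class_to_idx`, `samples` and `targets` attributes.

```python
import os

# Illustrative subset of the image extensions ImageFolder accepts (assumption)
IMG_EXTS = (".jpg", ".jpeg", ".png")

def folder_labels(root, exts=IMG_EXTS):
    """Rough approximation of ImageFolder's labeling: one class per sub-folder of root."""
    # Class names are the immediate sub-directories, sorted alphabetically
    with os.scandir(root) as entries:
        classes = sorted(e.name for e in entries if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}

    samples = []
    for name in classes:
        # Walk the class folder (including nested folders) in a deterministic order
        for dirpath, _, filenames in sorted(os.walk(os.path.join(root, name))):
            for fname in sorted(filenames):
                # Keep only files whose extension looks like an image (case-insensitive)
                if fname.lower().endswith(exts):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return class_to_idx, samples
```

With a layout like `imgs/cat/*.png` and `imgs/dog/*.png`, every cat image gets label 0 and every dog image gets label 1.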

My question is: is there any direct/easy way to define trainset as a sub-sample of the images in this folder? For example, can I define trainset to be a random sample of 10 images from every sub-folder?

Asked by Dr. John; edited by Henry Ecker
  • I do not think there is a *quick* solution for this kind of problem. The best option is probably a custom sampler, but that is not a one-line solution. Let me know if you are interested anyway and I can try to help. – Manuel Lagunas Jun 11 '18 at 15:57


You can wrap DatasetFolder (or ImageFolder) in another class that limits the size of the dataset:

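from torch.utils import data

class LimitDataset(data.Dataset):
    """Expose only the first n items of the wrapped dataset."""
    def __init__(self, dataset, n):
        self.dataset = dataset
        # Never report more items than the wrapped dataset actually has
        self.n = min(n, len(dataset))

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # Positions 0..n-1 map directly onto the first n items
        if not 0 <= i < self.n:
            raise IndexError(i)
        return self.dataset[i]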

You can also define a mapping from each index in LimitDataset to an index in the original dataset to get more complex behavior (such as random subsets).
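A minimal sketch of that index-mapping idea for the question's example (10 random images per sub-folder), assuming the wrapped dataset exposes a `targets` list of class indices, as ImageFolder does. `sample_per_class` and `IndexedSubset` are hypothetical names; `IndexedSubset` is written as a plain class so the sketch runs without PyTorch, and it does essentially what `torch.utils.data.Subset` already provides.

```python
import random
from collections import defaultdict

def sample_per_class(targets, k, seed=None):
    """Return sorted indices of up to k randomly chosen samples for each class label."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)
    chosen = []
    for label in sorted(by_class):
        idxs = by_class[label]
        # Take k random indices, or all of them if the class has fewer than k
        chosen.extend(rng.sample(idxs, min(k, len(idxs))))
    return sorted(chosen)

class IndexedSubset:
    """Map position i of the subset onto indices[i] of the wrapped dataset."""
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]
```

With the question's setup, something like `indices = sample_per_class(trainset.targets, k=10, seed=0)` followed by `torch.utils.data.Subset(trainset, indices)` (or `IndexedSubset(trainset, indices)`) can be passed to `DataLoader` in place of trainset. The sample is drawn once, so every epoch sees the same subset.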

If you want to limit the batches per epoch instead of the dataset size:

from itertools import islice

# Iterate over at most batches_per_epoch batches of the loader
for batch in islice(dataloader, batches_per_epoch):
    ...

Note that if you combine this with shuffle=True, the dataset size stays the same, but each epoch sees only a limited, randomly chosen part of the data. If you don't shuffle, every epoch sees the same first batches, so this effectively limits the dataset size as well.
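To see the effect of shuffling, here is a small pure-Python simulation. `fake_loader` is a hypothetical stand-in for a `DataLoader` that only yields batches of item indices, and the sizes (100 items, batch size 4, 5 batches per epoch) are arbitrary assumptions. Each epoch is truncated with `islice`, as in the loop above.

```python
import random
from itertools import islice

def fake_loader(n_items, batch_size, shuffle, rng):
    """Stand-in for a DataLoader: yields lists of item indices, one list per batch."""
    order = list(range(n_items))
    if shuffle:
        rng.shuffle(order)  # a new permutation every time the loader is iterated
    for start in range(0, n_items, batch_size):
        yield order[start:start + batch_size]

def items_seen(n_epochs, shuffle, n_items=100, batch_size=4, batches_per_epoch=5, seed=0):
    """Collect the distinct items visited when each epoch is cut short with islice."""
    rng = random.Random(seed)
    seen = set()
    for _ in range(n_epochs):
        batches = fake_loader(n_items, batch_size, shuffle, rng)
        for batch in islice(batches, batches_per_epoch):
            seen.update(batch)
    return seen
```

Without shuffling, `items_seen(10, shuffle=False)` is always items 0 to 19, no matter how many epochs run; with `shuffle=True` each epoch still sees 20 items, but the union across epochs grows toward the full dataset.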

Answered by Fábio Perez; edited by Thaiminhpv
  • Fábio Perez: Thanks for your answer. Could you please add the lines I should write to use the LimitDataset class you wrote (instead of ImageFolder)? I am new to Python and PyTorch and am having a lot of trouble with this. Thanks! – Dr. John Jun 12 '18 at 06:49
  • For instance, `LimitDataset(trainset, n=200)` limits trainset to 200 data points. – Fábio Perez Jun 13 '18 at 12:41
  • Great. Thanks a lot! – Dr. John Jun 13 '18 at 12:58