
Say I am loading MNIST from torchvision.datasets.MNIST, but I only want to load 10,000 images in total. How would I slice the data to limit it to some number of data points? I understand that the DataLoader is a generator yielding data in the specified batch size, but how do you slice datasets?

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)
train_loader = DataLoader(tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
– mikal94305

3 Answers


You can use torch.utils.data.Subset() e.g. for the first 10,000 elements:

import torch
import torch.utils.data as data_utils

# Indices of the first 10,000 samples
indices = torch.arange(10000)
tr_10k = data_utils.Subset(tr, indices)
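
The subset behaves like any other Dataset, so it drops straight into a DataLoader. A minimal sketch of that follow-up step (the batch size here is illustrative, not from the question):

from torch.utils.data import DataLoader

# Subset implements __getitem__/__len__, so DataLoader accepts it directly
train_loader_10k = DataLoader(tr_10k, batch_size=64, shuffle=True, num_workers=4)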
– iacob

It is important to note that when you create the DataLoader object, it doesn't immediately load all of your data (that would be impractical for large datasets). It provides you an iterator that you can use to access each batch of samples on demand.
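
A quick sketch of what that iterator looks like in practice (the printed shape assumes the usual MNIST ToTensor transform and a batch size of 64):

batch_iter = iter(train_loader)    # constructing the loader loads nothing yet
images, labels = next(batch_iter)  # exactly one batch is loaded on demand
print(images.shape)                # e.g. torch.Size([64, 1, 28, 28])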

Unfortunately, DataLoader doesn't provide you with any way to control the number of samples you wish to extract. You will have to use the typical ways of slicing iterators.

The simplest thing to do (without any extra libraries) is to stop once the required number of samples has been seen. Note that enumerate indexes batches, not individual samples, so the loop has to account for the batch size:

nsamples = 10000
for i, (images, labels) in enumerate(train_loader):
    # i counts batches, so stop once roughly nsamples samples have been seen
    if i * args.batch_size >= nsamples:
        break

    # Your training code here.

Or, you could use itertools.islice to get roughly the first 10k samples. Note that islice takes its stop value positionally, and that slicing the loader slices batches, not samples:

import itertools

# Take the first 10000 // batch_size batches, i.e. about 10,000 samples
for images, labels in itertools.islice(train_loader, 10000 // args.batch_size):

    # Your training code here.
– entrophy
  • A warning for this method: if you iterate `train_loader` multiple times in a loop over `epoch`, you may end up using all the samples for training, because `shuffle=True` in `DataLoader` reshuffles the samples for each epoch. – Libin Wen Aug 28 '19 at 11:52
  • I keep getting errors like `DataLoader worker (pid(s) 9579) exited unexpectedly` with these methods (on OSX). – Tom Roth Apr 14 '20 at 08:12
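
One way to sidestep the reshuffling pitfall from the first comment is to fix the indices once, outside the epoch loop, so every epoch shuffles within the same subset. A sketch reusing the Subset approach from the answer above (num_epochs and the batch size are placeholders):

import torch
from torch.utils.data import DataLoader, Subset

fixed_10k = Subset(tr, torch.arange(10000))  # same 10,000 samples every epoch
loader = DataLoader(fixed_10k, batch_size=64, shuffle=True)

for epoch in range(num_epochs):  # num_epochs assumed defined elsewhere
    for images, labels in loader:
        pass  # your training code here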

Another quick way of slicing a dataset is torch.utils.data.random_split() (supported in PyTorch v0.4.1+). It randomly splits a dataset into non-overlapping new datasets of the given lengths.

So we can have something like the following:

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)

tr_split_len = 10000  # e.g. the 10,000 training samples from the question
te_split_len = 1000   # e.g. an illustrative test split

part_tr = torch.utils.data.random_split(tr, [tr_split_len, len(tr) - tr_split_len])[0]
part_te = torch.utils.data.random_split(te, [te_split_len, len(te) - te_split_len])[0]

train_loader = DataLoader(part_tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(part_te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)

Here you can set tr_split_len and te_split_len to the required split lengths for the training and testing datasets respectively.
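
As a side note, in newer PyTorch versions (1.6+) random_split also accepts a generator argument, which makes the split reproducible. A minimal sketch (the seed value is arbitrary):

import torch

tr_split_len = 10000
part_tr = torch.utils.data.random_split(
    tr,
    [tr_split_len, len(tr) - tr_split_len],
    generator=torch.Generator().manual_seed(0),  # fixed seed -> same split every run
)[0]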

– srihegde