
I'm trying to split one of the PyTorch custom datasets (MNIST) into a training set and a validation set as follows:

import numpy as np
import torch
from torch.utils.data import sampler
from torchvision import datasets, transforms

def get_train_valid_splits(data_dir,
                           batch_size,
                           random_seed=1,
                           valid_size=0.2,
                           shuffle=True,
                           num_workers=4,
                           pin_memory=False):

    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST

    # define transforms
    valid_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

    train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

    # load the dataset
    train_dataset = datasets.MNIST(root=data_dir, train=True,
                download=True, transform=train_transform)

    valid_dataset = datasets.MNIST(root=data_dir, train=True,
                download=True, transform=valid_transform)

    dataset_size = len(train_dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(valid_size * dataset_size))

    
    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = sampler.SubsetRandomSampler(train_idx)
    valid_sampler = sampler.SubsetRandomSampler(valid_idx)

    print(len(train_sampler))
    print(len(valid_sampler))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                    batch_size=batch_size, sampler=train_sampler,
                    num_workers=num_workers, pin_memory=pin_memory)

    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                    batch_size=batch_size, sampler=valid_sampler,
                    num_workers=num_workers, pin_memory=pin_memory)

    print(len(train_loader.dataset))
    print(len(valid_loader.dataset))

    return (train_loader, valid_loader)

After calling the function, I notice that the sampler lengths look right, 48000 and 12000:

print(len(train_sampler))
print(len(valid_sampler))

But when I look at the length of the data set associated with train_loader and valid_loader:

print(len(train_loader.dataset))
print(len(valid_loader.dataset))

I get the same length for both: 60000! Any idea what is going on here? Why is it giving the same length for both, even though I clearly split it by indices?

user6496380

2 Answers


It's because the DataLoader doesn't modify the dataset you pass it; it only "applies" things like the batch size and sampler when you iterate over it. Your issue is that you're using len(loader.dataset), which gives the length of the underlying dataset without modification, when you really want len(loader), which is the length after the sampler and batch size are applied.

import torch
import numpy as np

dataset = np.random.rand(100, 200)   # 100 samples, 200 features each
sampler = torch.utils.data.SubsetRandomSampler(list(range(70)))  # only the first 70 indices

loader = torch.utils.data.DataLoader(dataset, sampler=sampler)
print(len(loader))           # number of batches (batch_size defaults to 1)
>>> 70
print(len(loader.dataset))   # underlying dataset, untouched by the sampler
>>> 100

Note: the result of len(loader) is affected by the batch size:

# with batch size
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=2)
print(len(loader))           # 70 samples / batch_size 2 = 35 batches
>>> 35
print(len(loader.dataset))
>>> 100
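
A related check, assuming the loader from the snippet above: the DataLoader keeps a reference to the sampler it was given, so len(loader.sampler) gives the number of individual samples that will be drawn, independent of the batch size.

print(len(loader.sampler))   # samples yielded by the sampler, not batches
>>> 70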
Jay Mody
  • Thanks for this! So len(loader) will give you num_samples/batch_size, right? So to get the size of the full dataset, would you need to do len(loader)*batch_size, or is there an easier way to do that? – user6496380 Jul 25 '20 at 19:05
  • @user6496380 Just updated my response, but yes, that's correct: you would need to multiply the len by batch_size. However, note that if num_samples does not divide evenly by the batch size, len(loader) * batch_size won't exactly match the number of samples. – Jay Mody Jul 25 '20 at 19:07
  • I see, thanks! One last question: in addition to using train_loader and valid_loader, can you split the dataset yourself on the same indices that the DataLoader used (via sampler=...), to get train_dataset and valid_dataset? All the data, essentially. – user6496380 Jul 25 '20 at 19:12
  • You can always use something like `torch.utils.data.random_split()` (see the sketch below). In this scenario, you would use a random sampler instead of a subset random sampler, since the datasets are already split before being passed to the dataloaders. – Jay Mody Jul 25 '20 at 19:31
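
A minimal sketch of that random_split() route, assuming the same MNIST data, data_dir, and batch_size as in the question (one possible layout, not the only one):

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
full_train = datasets.MNIST(root=data_dir, train=True,
                            download=True, transform=transform)

# 60000 images -> 48000 train / 12000 validation, split at random
valid_size = int(0.2 * len(full_train))
train_dataset, valid_dataset = random_split(
    full_train, [len(full_train) - valid_size, valid_size])

# random_split already picks the indices at random, so plain shuffling on
# the training loader replaces the SubsetRandomSampler
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size)

Unlike the sampler approach, len(train_loader.dataset) and len(valid_loader.dataset) now report 48000 and 12000, since each loader wraps its own Subset.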

The reason train_loader.dataset and valid_loader.dataset have the same length is that you used the same data for train_dataset and valid_dataset.

You want

valid_dataset = datasets.MNIST(root=data_dir, train=False,
                               download=True, transform=valid_transform)

(not train=True) to download the validation set.
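
For context, train=False loads MNIST's held-out test split of 10,000 images rather than carving a validation set out of the 60,000 training images, so the two dataset lengths would then differ. A quick check, assuming the same data_dir and valid_transform as in the question:

valid_dataset = datasets.MNIST(root=data_dir, train=False,
                               download=True, transform=valid_transform)
print(len(valid_dataset))
>>> 10000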

James Hirschorn