I'm trying to split one of the PyTorch datasets (MNIST, via torchvision) into a training set and a validation set as follows:
    import numpy as np
    import torch
    from torch.utils.data import sampler
    from torchvision import datasets, transforms

    def get_train_valid_splits(data_dir,
                               batch_size,
                               random_seed=1,
                               valid_size=0.2,
                               shuffle=True,
                               num_workers=4,
                               pin_memory=False):
        normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST

        # define transforms
        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        # load the dataset
        train_dataset = datasets.MNIST(root=data_dir, train=True,
                                       download=True, transform=train_transform)
        valid_dataset = datasets.MNIST(root=data_dir, train=True,
                                       download=True, transform=valid_transform)

        dataset_size = len(train_dataset)
        indices = list(range(dataset_size))
        split = int(np.floor(valid_size * dataset_size))

        if shuffle:
            np.random.seed(random_seed)
            np.random.shuffle(indices)

        train_idx, valid_idx = indices[split:], indices[:split]
        train_sampler = sampler.SubsetRandomSampler(train_idx)
        valid_sampler = sampler.SubsetRandomSampler(valid_idx)
        print(len(train_sampler))
        print(len(valid_sampler))

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   sampler=train_sampler,
                                                   num_workers=num_workers,
                                                   pin_memory=pin_memory)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=batch_size,
                                                   sampler=valid_sampler,
                                                   num_workers=num_workers,
                                                   pin_memory=pin_memory)
        print(len(train_loader.dataset))
        print(len(valid_loader.dataset))

        return (train_loader, valid_loader)
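For reference, this is roughly how I call the function (the data directory and batch size here are just placeholder values):

    # hypothetical call, just to show how the function is used;
    # './data' and 64 are placeholder values
    train_loader, valid_loader = get_train_valid_splits(data_dir='./data',
                                                        batch_size=64)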
After calling the function, the lengths of the samplers look right, 48000 and 12000:

    print(len(train_sampler))
    print(len(valid_sampler))
But when I look at the length of the dataset associated with train_loader and valid_loader:

    print(len(train_loader.dataset))
    print(len(valid_loader.dataset))
I get the same length for both: 60000! Any idea what is going on here? Why does each loader report the same length, even though I clearly split the dataset by indices?
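In case it helps, this is a small check I could run to count how many samples each loader actually yields over one pass (just a sketch, assuming the train_loader and valid_loader returned by the function above):

    # sketch: count the samples each DataLoader actually yields in one epoch
    # (assumes train_loader and valid_loader come from the function above)
    def count_samples(loader):
        total = 0
        for images, labels in loader:
            total += images.size(0)  # number of samples in this batch
        return total

    print(count_samples(train_loader))
    print(count_samples(valid_loader))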