
Let's say I have a DataLoader for CIFAR10. If I want to remove some samples from it and make a new DataLoader, how should I do it?

import torch
import torchvision
import torchvision.transforms as transforms

# Assumption: `transform` is defined elsewhere in the original script; ToTensor() is a placeholder.
transform = transforms.ToTensor()

def load_data_cifar10(batch_size=128, test=False):
    # train=True loads the 50,000 training images, train=False the 10,000 test images
    dset = torchvision.datasets.CIFAR10(root='/mnt/3CE35B99003D727B/input/pytorch/data', train=not test,
                                        download=True, transform=transform)
    loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=True)
    # len(loader) is the number of batches, not the number of samples
    print("LOAD DATA, %d" % (len(loader)))
    return loader

1 Answer


You can use the Subset dataset. It takes another dataset and a list of indices as input and constructs a new dataset containing only the entries at those indices. Say you want the first 1000 entries; then you could do

subset_train_dset = torch.utils.data.Subset(train_dset, range(1000))
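A minimal sketch of the full round trip the question asks about: remove entries, then build a new DataLoader on top of the result. To keep it runnable without downloading CIFAR10, it assumes a small synthetic TensorDataset named `full_dset` as a hypothetical stand-in for `train_dset`; the sizes and batch size are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for train_dset: 10 fake 3x32x32 "images" with labels 0..9,
# used so the sketch runs without downloading CIFAR10.
images = torch.randn(10, 3, 32, 32)
labels = torch.arange(10)
full_dset = TensorDataset(images, labels)

# Keep only the first 6 samples; Subset wraps the dataset, not the DataLoader.
subset_dset = Subset(full_dset, range(6))

# A fresh DataLoader built on top of the subset.
subset_loader = DataLoader(subset_dset, batch_size=4, shuffle=False)

# Collect the size of each batch the new loader yields.
batch_sizes = [img.shape[0] for img, target in subset_loader]
print(len(subset_dset), batch_sizes)  # 6 [4, 2]
```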

You can also construct datasets composed of multiple datasets using the ConcatDataset dataset, or combine ConcatDataset and Subset to build whatever you like:

frankenstein_dset = torch.utils.data.ConcatDataset((
    torch.utils.data.Subset(dset1, range(1000)),
    torch.utils.data.Subset(dset2, range(100))))
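A sketch of what the combined dataset looks like, assuming `dset1` and `dset2` are small synthetic TensorDatasets (hypothetical stand-ins; the answer leaves them unspecified). The length of a ConcatDataset is the sum of its parts' lengths, and indices run through the parts in order.

```python
import torch
from torch.utils.data import ConcatDataset, Subset, TensorDataset

# Hypothetical stand-ins for dset1 (labels all 0) and dset2 (labels all 1).
dset1 = TensorDataset(torch.zeros(2000, 3), torch.zeros(2000, dtype=torch.long))
dset2 = TensorDataset(torch.ones(500, 3), torch.ones(500, dtype=torch.long))

frankenstein_dset = ConcatDataset((
    Subset(dset1, range(1000)),   # first 1000 samples of dset1
    Subset(dset2, range(100))))   # first 100 samples of dset2

# Indices 0..999 map into the dset1 subset, indices 1000..1099 into the dset2 subset.
total = len(frankenstein_dset)
last_from_dset1 = frankenstein_dset[999][1].item()
first_from_dset2 = frankenstein_dset[1000][1].item()
print(total, last_from_dset1, first_from_dset2)  # 1100 0 1
```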

In your case you would need to either look into the implementation details to determine which indices to keep, or write some code that first iterates through the original dataset and saves all the indices you want to keep, then define a Subset with those indices. Pass the dataset (e.g. train_dset) to Subset, not the DataLoader, and then create a new DataLoader from the Subset.
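One way to do the "iterate and save indices" approach, assuming the goal is to drop every sample of one class (label 3 here is an arbitrary, hypothetical choice) and using a synthetic TensorDataset in place of CIFAR10. On a real torchvision CIFAR10 object, the labels are also available as the list `train_dset.targets`, which avoids loading and transforming every image just to read its label.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for the CIFAR10 training set: 20 fake images, labels cycle 0..9.
images = torch.randn(20, 3, 32, 32)
labels = torch.arange(20) % 10
train_dset = TensorDataset(images, labels)

# Example criterion (an assumption): drop every sample whose label is 3.
label_to_remove = 3

# Pass 1: iterate through the original dataset and record the indices to keep.
keep_indices = [i for i in range(len(train_dset))
                if train_dset[i][1].item() != label_to_remove]

# Pass 2: wrap those indices in a Subset and build a new DataLoader from it.
filtered_dset = Subset(train_dset, keep_indices)
filtered_loader = DataLoader(filtered_dset, batch_size=8, shuffle=True)

print(len(train_dset), len(filtered_dset))  # 20 18
```

Because `keep_indices` is computed once, `shuffle=True` on the new loader only reorders the kept samples; the removed ones never reappear.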

jodag
  • `subset_train_dset.dataset.data.shape` is not working with a real DataLoader, and `for idx, (img, target) in enumerate(dataloader):` raises the error `TypeError: 'DataLoader' object is not subscriptable` –  Dec 19 '19 at 23:46
  • `subset_train_dset = torch.utils.data.Subset(dataloader, range(1000))` followed by `for idx, (img, target) in enumerate(subset_train_dset): print(idx, ':', img.shape)` –  Dec 19 '19 at 23:47
  • Datasets don't generally have a `.data` member (that's only for specific datasets). Read the docs for more info, but the only things a dataset is required to implement are `__getitem__` and `__len__`. So if you want the length of a dataset object, you must use `len(dataset)` for a general solution. – jodag Dec 20 '19 at 01:15
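A sketch illustrating the last comment, using a hypothetical `SquaresDataset` that implements only `__getitem__` and `__len__`. It also shows a fix for the `TypeError` above: Subset must wrap a dataset, and if you only have a DataLoader you can get its dataset back through its `dataset` attribute.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset


class SquaresDataset(Dataset):
    """Hypothetical minimal dataset: implements only __getitem__ and __len__."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor([float(idx)])
        return x, x ** 2


loader = DataLoader(SquaresDataset(100), batch_size=16, shuffle=True)

# Subset needs a dataset, not a DataLoader; an existing loader exposes its dataset as loader.dataset.
subset_dset = Subset(loader.dataset, range(10))
subset_loader = DataLoader(subset_dset, batch_size=4)

# No .data attribute is involved: len() works for any dataset that defines __len__.
print(len(loader.dataset), len(subset_dset), len(subset_loader))  # 100 10 3

for idx, (img, target) in enumerate(subset_loader):
    print(idx, ':', img.shape)
```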