I have two tensors:

x[train], y[train]

Their shapes are

(311, 3, 224, 224), (311,)  # 311 is just the number of samples

I want to use a DataLoader to load them batch by batch. The code I wrote is:

import torch.utils.data as Data
from torch.utils.data import Dataset

class KD_Train(Dataset):

    def __init__(self, a, b):
        self.imgs = a
        self.index = b

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        return self.imgs, self.index

kdt = KD_Train(x[train], y[train])

train_data_loader = Data.DataLoader(
    kdt,
    batch_size=64,
    shuffle=True,
    num_workers=0)

for step, (a, b) in enumerate(train_data_loader):
    print(a.shape)
    break

But it shows:

(64, 311, 3, 224, 224)

The DataLoader just adds a batch dimension in front of the whole tensor instead of selecting samples for each batch. Does anyone know what I should do?

desertnaut
Tim

1 Answer

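Your dataset's __getitem__ method should return a single sample, not the whole tensors. The DataLoader calls __getitem__ once per index and stacks the results along a new first dimension, which is why returning self.imgs produces a (64, 311, 3, 224, 224) batch: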

def __getitem__(self, index):
    return self.imgs[index], self.index[index]
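
For reference, here is a minimal end-to-end sketch of the corrected dataset. The tensors x_train and y_train are random stand-ins (an assumption, not the asker's data), and the images are shrunk to 3x8x8 so the check runs quickly; the attribute name labels is also just illustrative.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class KD_Train(Dataset):
    def __init__(self, imgs, labels):
        self.imgs = imgs
        self.labels = labels

    def __len__(self):
        # Number of samples, i.e. the size of the first dimension.
        return len(self.imgs)

    def __getitem__(self, index):
        # Return ONE sample; the DataLoader stacks batch_size of these.
        return self.imgs[index], self.labels[index]


# Stand-in data (assumption): 311 small random "images" and integer labels.
x_train = torch.randn(311, 3, 8, 8)
y_train = torch.randint(0, 10, (311,))

loader = DataLoader(KD_Train(x_train, y_train),
                    batch_size=64, shuffle=True, num_workers=0)

# Collect the shape of every batch as plain tuples.
shapes = [(tuple(a.shape), tuple(b.shape)) for a, b in loader]
print(shapes[0])   # ((64, 3, 8, 8), (64,))
print(shapes[-1])  # ((55, 3, 8, 8), (55,)) since 311 = 4 * 64 + 55
```

If the dataset only needs to index a pair of tensors like this, torch.utils.data.TensorDataset(x_train, y_train) should give the same batches without a custom class.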
Ivan