I have raw image data saved in separate CSV files (one image per file). I want to train a CNN on them using PyTorch. How should I load the data so that it is suitable as input to the CNN? (Also, the images have 1 channel, while ImageNet-style models expect RGB input by default.)

forouzanf

1 Answer

PyTorch's DataLoader, as the name suggests, is simply a utility class that helps you load your data in parallel, build batches, shuffle, and so on. What you need instead is a custom Dataset implementation.

Setting aside that storing images in CSV files is somewhat unusual, you simply need something like this:

import glob
from pathlib import Path
from typing import Any

from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):

    def __init__(self, path: Path):
        # add any other constructor arguments you need
        # do some preliminary checks, e.g. your path exists, files are there...
        assert path.exists(), f"{path} does not exist"
        # retrieve your files in some way, e.g. glob (sorted for a stable order)
        self.csv_files = sorted(glob.glob(str(path / "*.csv")))

    def __len__(self) -> int:
        # this lets you know len(dataset) once you instantiate it
        return len(self.csv_files)

    def __getitem__(self, index: int) -> Any:
        # this method is called by the dataloader, each index refers to
        # a CSV file in the list you built in the constructor
        csv = self.csv_files[index]
        # now do whatever you need to do and return some tensors;
        # load_image is a method you have to write yourself (not shown)
        image, label = self.load_image(csv)
        return image, label
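
The `load_image` step is left open above, since it depends on how your CSV files are laid out. A minimal sketch, assuming each file holds one H x W grid of comma-separated pixel values and that the class label is encoded in the file name (e.g. `3_0001.csv` for class 3). Both are assumptions about your data, and `csv_to_tensor` and `label_from_name` are hypothetical helper names, not part of PyTorch:

```python
from pathlib import Path

import numpy as np
import torch


def csv_to_tensor(csv_path: Path, as_rgb: bool = False) -> torch.Tensor:
    """Read one CSV image into a float tensor of shape (C, H, W)."""
    # Assumption: comma-separated numeric values, one image row per line.
    pixels = np.loadtxt(csv_path, delimiter=",", dtype=np.float32, ndmin=2)
    image = torch.from_numpy(pixels).unsqueeze(0)  # (H, W) -> (1, H, W)
    if as_rgb:
        # Copy the single channel three times for models that expect RGB input.
        image = image.repeat(3, 1, 1)  # (1, H, W) -> (3, H, W)
    return image


def label_from_name(csv_path: Path) -> int:
    """Hypothetical labelling scheme: '<class>_<id>.csv', e.g. '3_0001.csv'."""
    return int(Path(csv_path).stem.split("_")[0])
```

Inside the dataset, `load_image` could then return `csv_to_tensor(Path(csv)), label_from_name(Path(csv))`. For the 1-channel question, repeating the channel (`as_rgb=True`) is one option; another is to keep the single channel and give the network's first convolution `in_channels=1`.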

And that's it, more or less. You can then create your dataset, pass it to a dataloader and iterate the dataloader, something like:

dataset = CustomDataset(Path("path/to/csv/files"))
# pass any further DataLoader arguments (e.g. batch_size) as needed
train_loader = DataLoader(dataset, shuffle=True, num_workers=8)

for batch in train_loader:
    ...
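
Each batch the loader yields is whatever `__getitem__` returns, stacked along a new first dimension by the default collate function. A small sketch with random stand-in data, showing the shapes a CNN would receive; the `FakeImages` class, the 28x28 image size and the batch size are made up for illustration:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakeImages(Dataset):
    """Stand-in dataset: 10 random single-channel 28x28 images."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, index: int):
        image = torch.rand(1, 28, 28)  # (C, H, W) for one sample
        label = index % 2              # dummy integer label
        return image, label


# num_workers=0 loads in the main process, which is simplest for a quick check.
loader = DataLoader(FakeImages(), batch_size=4, shuffle=True, num_workers=0)

images, labels = next(iter(loader))
# images: (batch, channels, height, width); labels: (batch,)
print(images.shape, labels.shape)
```

If the printed image shape is not `(batch, channels, height, width)`, the fix usually belongs in `__getitem__`, not in the training loop.
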
edornd