I am trying to train an Inception network on a custom dataset of fruit images that belong to two classes, 0 and 1. The class of each image is listed in a CSV file, whose contents are shown below:
fruit0626.JPG 0
fruit0006.JPG 0
fruit0007.JPG 0
fruit0008.JPG 0
fruit0009.JPG 0
fruit0010.JPG 0
fruit0001.JPG 0
fruit0002.JPG 0
fruit0003.JPG 0
fruit0004.JPG 0
fruit0005.JPG 0
fruit0186.JPG 0
fruit0187.JPG 0
fruit0188.JPG 0
fruit0189.JPG 0
fruit0190.JPG 0
fruit0141.JPG 0
fruit0142.JPG 0
fruit0143.JPG 0
fruit0144.JPG 0
fruit0145.JPG 0
fruit0146.JPG 0
fruit0147.JPG 0
fruit0148.JPG 0
fruit0149.JPG 0
fruit0150.JPG 0
fruit0031.JPG 0
fruit0032.JPG 0
fruit0033.JPG 0
fruit0034.JPG 0
fruit0035.JPG 0
fruit0156.JPG 0
fruit0157.JPG 0
fruit0158.JPG 0
fruit0159.JPG 0
fruit0160.JPG 0
Each row contains the image name and the class it belongs to (0 or 1).
Below is the code that I have used:
from google.colab import drive
drive.mount('/content/drive')
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from skimage import io
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
class kinoo(Dataset):
    """Map-style dataset of images listed in a CSV annotation file.

    Column 0 of the parsed CSV holds the image file name (joined onto
    ``root_dir``) and column 1 holds the integer class label.

    Args:
        csv_file: Path of the annotation CSV.
        root_dir: Directory that the file names in column 0 are joined to.
        transform: Optional callable applied to each loaded image.
        **read_csv_kwargs: Extra keyword arguments forwarded unchanged to
            ``pandas.read_csv``. With pandas' defaults the first line of the
            file is consumed as a header and fields are split on commas, so
            a headerless, whitespace-separated file needs e.g.
            ``header=None, sep=r"\\s+"``.

    Raises:
        ValueError: If the parsed CSV has fewer than two columns.
    """

    def __init__(self, csv_file, root_dir, transform=None, **read_csv_kwargs):
        self.annotations = pd.read_csv(csv_file, **read_csv_kwargs)
        self.root_dir = root_dir
        self.transform = transform

        # Fail early and loudly. Without this check a mis-parsed file only
        # surfaces as an IndexError inside __getitem__, and sequence-style
        # iteration (e.g. ``list(dataset)``) treats IndexError as "end of
        # data", which hides the real problem.
        n_columns = self.annotations.shape[1]
        if n_columns < 2:
            raise ValueError(
                f"{csv_file!r} was parsed into {n_columns} column(s); "
                "expected at least 2 (image file name, class label). "
                "Check the delimiter/header options passed to pandas.read_csv."
            )

    def __len__(self) -> int:
        """Return the number of annotation rows."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the annotation row at ``index``.

        The image is read with ``io.imread`` and passed through
        ``self.transform`` when one was supplied; the label is a 0-dim
        integer tensor built from column 1. An out-of-range ``index``
        raises ``IndexError`` (from ``DataFrame.iloc``).
        """
        image_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        image = io.imread(image_path)
        y_label = torch.tensor(int(self.annotations.iloc[index, 1]))

        # Compare against None rather than relying on truthiness, so a
        # callable that happens to be falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return (image, y_label)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
in_channel = 3
learning_rate = 1e-3
batch_size = 32
num_epochs = 50

# Transforms, applied in this order to every image the dataset returns:
# convert to a tensor, normalize each channel (these are the commonly used
# ImageNet mean/std values), then resize to 640x640.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    transforms.Resize((640, 640)),
])

# Load data
dataset = kinoo(
    csv_file='/content/drive/MyDrive/aza-exp/aza-fruit-regression-v2/fruit_classification.csv',
    root_dir='/content/drive/MyDrive/aza-exp/aza-fruit-regression-v2',
    transform=transform,
)

# 80/20 split using explicit integer lengths. Fractional lengths such as
# [0.8, 0.2] are only accepted by newer torch releases; older ones raise
# "Sum of input lengths does not equal the length of the input dataset!".
# Any remainder goes to the training split, as it does with the fractional form.
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - test_size
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
# Evaluation does not benefit from shuffling; keep the test order stable.
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

# NOTE: debugging aid - this loads and transforms every training sample
# before printing, so it is slow and memory-hungry on a large dataset.
print(list(train_set))
This gave me the following error:
I found two similar questions with the same error: link 1 link 2
But these solutions didn't work for me. I've searched the internet and couldn't troubleshoot it on my own. Any help explaining why this error occurs and how I can solve it would be greatly appreciated.