-1

When I run my code, the train loop never finishes. When it prints out, telling where it is, it has way exceeded the 300 Datapoints, which I told the program there to be, but also the 42000, which are actually there in the csv file. Why doesn't it stop automatically after 300 Samples?

Thanks guys.

My Code: (I left out the Net and the test loop for readability)

import torch
from torch.utils.data import Dataset 
from torch.utils.data import DataLoader
from torch import nn
from torchvision.transforms import ToTensor
#import os
import pandas as pd
#import numpy as np
import random
import time


#Hyperparameters
batch_size = 3
learning_rate = 8e-3



#DataSet
class CustomImageDataset(Dataset):
    def __init__(self, img_dir, batches):
        self.img_dir = img_dir
        self.batches =batches
        self.data=pd.read_csv(self.img_dir)
        #self.data=pd.read_csv("01.Actual/02.NeuralNetwork/01.OffTopicTests/NumberRecognizer/train.csv")


    def __len__(self):
        #return len(self.data)
        return 300

    def __getitem__(self, idx):
        
        
        images =[]
        labels = torch.zeros(self.batches,dtype=int)
        for x in range(self.batches):
            label =self.data.at[(idx+x),'label']
            label = label.item()
            #image = torch.zeros(1,1,784) #,dtype=torch.int32
            image = torch.zeros(1,28,28) #,dtype=torch.int32

            for i in range(784):

                z = int(i%28)
                y= int((i-x)/28)

                column = 'pixel' +str(i)

                #image[0,0,i]=self.data.at[(idx+x),column]
                image[0,z,y]=self.data.at[(idx+x),column]
            
            images.append(image)
            labels[x] = label
            
        return torch.stack(images), labels
        
#DataLoader
train_loader=DataLoader(CustomImageDataset,batch_size, shuffle = False, drop_last= True)

#Creating Instances
Data =CustomImageDataset("01.Actual/02.NeuralNetwork/01.OffTopicTests/NumberRecognizer/train.csv",batch_size)
model =NeuralNetwork()


#Hyperparameters


epochs = 1
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)


#Creating training loop
def train_loop(dataloader,model,loss_fn,optimizer,batch_size):
    size=Data.__len__()
    
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        


        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()



        if batch % 100 == 0:
            #print(len(X))
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")    
       


#Executing Part of the script
for t in range(epochs):
    print(f"Epoch {t+1}\n ----------------")
    train_loop(Data,model,loss_fn,optimizer,batch_size)
    test_loop(Data,model,loss_fn)

tridentifer
  • 29
  • 1
  • 4
  • Does this answer your question? [Taking subsets of a pytorch dataset](https://stackoverflow.com/questions/47432168/taking-subsets-of-a-pytorch-dataset) – maciek97x Apr 14 '23 at 09:23
  • Not quite . So the problem is, that the training loop continues to run until it breaks, even if i split the data. Somehow, the enumerate Loop doesn't stop in time. – tridentifer Apr 14 '23 at 10:06

1 Answers1

0

The problem is that when you define __getitem__ then iterating over this class does not use __len__ at all. If you want to limit the number of samples, then you should do it inside __getitem__ by raising StopIteration.

Example:

class Example(object):
    def __init__(self, a, b, c):
        self.data = list(range(a))
        self.len = b
        self.stop = c

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        if idx >= self.stop:
            raise StopIteration
        return self.data[idx]

Results:

In [1]: for x in Example(8, 4, 4):
   ...:     print(x, end=', ')
0, 1, 2, 3, 

In [2]: for x in Example(8, 4, 8):
   ...:     print(x, end=', ')
0, 1, 2, 3, 4, 5, 6, 7, 

In [3]: for x in Example(4, 4, 8):
   ...:     print(x, end=', ')
0, 1, 2, 3,  
maciek97x
  • 2,251
  • 2
  • 10
  • 21