
I have a training dataset of 600 images at 512×512×1 resolution, categorized into 2 classes (300 images per class). Using some augmentation techniques I have increased the dataset to 10,000 images. I then applied the following preprocessing steps:

import numpy as np

all_images = np.array(all_images) / 255.0         # scale pixel values to [0, 1]
all_images = all_images.astype('float16')         # halve the memory footprint
all_images = all_images.reshape(-1, 512, 512, 1)  # add the channel dimension
and saved the resulting arrays to an H5 file.
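A minimal sketch of that saving step with h5py (the file name 'dataset.h5', the dataset names, and the all_labels array are placeholders, not the exact names I used):

import h5py

# all_images: (N, 512, 512, 1) float16, all_labels: (N,) class indices (placeholder name)
with h5py.File('dataset.h5', 'w') as f:
    f.create_dataset('images', data=all_images, compression='gzip')
    f.create_dataset('labels', data=all_labels, compression='gzip')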

I am using an AlexNet-style architecture for classification, with 3 convolutional and 3 overlapping max-pool layers. I want to know which of the following approaches will be best for training on Google Colab, where memory is limited to 12 GB.

1. model.fit(x, y, validation_split=0.2)
# For this I have to load all the data into memory, and applying AlexNet to it will simply cause a resource-exhausted error.

2. model.train_on_batch(x, y)
# For this I have written a script which randomly loads data batch-wise from the H5 file into memory and trains on it (a sketch of that loop is below). I am confused by the property of train_on_batch(), i.e. that it performs a single gradient update per call. Will this affect my training procedure, or will it end up being the same as model.fit()?
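A rough sketch of that loading-and-training loop, assuming the model is already compiled and that the H5 file and dataset names ('dataset.h5', 'images', 'labels') are placeholders:

import numpy as np
import h5py

batch_size = 64
with h5py.File('dataset.h5', 'r') as f:           # placeholder file name
    images, labels = f['images'], f['labels']     # placeholder dataset names
    n = images.shape[0]
    for epoch in range(10):
        indices = np.random.permutation(n)        # reshuffle every epoch
        for start in range(0, n - batch_size + 1, batch_size):
            # h5py fancy indexing needs the indices in increasing order
            batch_idx = np.sort(indices[start:start + batch_size])
            x_batch, y_batch = images[batch_idx], labels[batch_idx]
            loss = model.train_on_batch(x_batch, y_batch)  # one gradient update per call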

3. model.fit_generator()
# giving the original directory of images to its data generator, which augments the data on the fly, and then training with model.fit_generator() (see the sketch after this item). I haven't tried this yet.
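A minimal sketch of what that could look like, assuming a directory with one sub-folder per class; the path and the particular augmentation settings are placeholders:

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1.0 / 255,
                             rotation_range=15,      # example augmentation settings
                             horizontal_flip=True,
                             validation_split=0.2)

train_gen = datagen.flow_from_directory('data/train',   # placeholder path, one sub-folder per class
                                        target_size=(512, 512),
                                        color_mode='grayscale',
                                        class_mode='binary',
                                        batch_size=16,
                                        subset='training')
val_gen = datagen.flow_from_directory('data/train',
                                      target_size=(512, 512),
                                      color_mode='grayscale',
                                      class_mode='binary',
                                      batch_size=16,
                                      subset='validation')

model.fit_generator(train_gen,
                    steps_per_epoch=len(train_gen),
                    epochs=10,
                    validation_data=val_gen,
                    validation_steps=len(val_gen))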

Please guide me on which of these methods will be best in my case. I have read many answers here, here, and here about model.fit(), model.train_on_batch(), and model.fit_generator(), but I am still confused.

zeeshan nisar

1 Answer


model.fit – suitable if you load the data as a numpy array and train without augmentation.
model.fit_generator – use it if your dataset is too big to fit in memory, and/or you want to apply augmentation on the fly.
model.train_on_batch – less common; usually used when training more than one model at a time (a GAN, for example).

Jenia Golbstein
  • Oh yes, I got you. Thank you so much for your time and help. Please guide me on one last thing: can I use model.fit_generator() for my saved augmented data, i.e. the data which I have augmented and saved as numpy arrays to an H5 file? – zeeshan nisar Dec 25 '18 at 12:37
  • I would load random batch-wise data from that H5 file into memory, say with a batch size of 64, and then train using fit_generator(). – zeeshan nisar Dec 25 '18 at 12:39
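A minimal sketch of such a generator over the saved H5 file (the file name 'dataset.h5' and the dataset names are assumptions, and the model is assumed to be compiled already):

import numpy as np
import h5py

def h5_batch_generator(path, batch_size=64):
    # Yields (x, y) batches from the H5 file forever, reshuffling every epoch.
    with h5py.File(path, 'r') as f:
        images, labels = f['images'], f['labels']   # assumed dataset names
        n = images.shape[0]
        while True:
            indices = np.random.permutation(n)
            for start in range(0, n - batch_size + 1, batch_size):
                batch_idx = np.sort(indices[start:start + batch_size])  # h5py needs sorted indices
                yield images[batch_idx], labels[batch_idx]

train_gen = h5_batch_generator('dataset.h5', batch_size=64)   # assumed file name
model.fit_generator(train_gen, steps_per_epoch=10000 // 64, epochs=10)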