I am trying to use Keras flow_from_directory to train a model, but it does not repeat the data after an epoch (i.e. once all the data has been iterated over), and I could not find any option to make it do so. Below is my code for generating data while training.
For example, if the total number of images is 70 and batch_size = 32, the 1st and 2nd iterations each give 32 images, but the 3rd gives only 6.
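Here is the arithmetic behind those numbers. One pass over 70 images in batches of 32 needs ceil(70/32) = 3 batches, and the last batch gets the remainder, 70 - 2*32 = 6. A small check in Python follows. The helper name `batch_sizes` is hypothetical, and it assumes an iterator that keeps the short final batch instead of dropping it:

```python
import math


def batch_sizes(n_images, batch_size):
    """Sizes of the batches yielded in one pass, keeping the short last batch."""
    n_batches = math.ceil(n_images / batch_size)
    # Every batch is full except possibly the last, which holds the remainder.
    return [min(batch_size, n_images - i * batch_size) for i in range(n_batches)]


print(batch_sizes(70, 32))  # [32, 32, 6]
```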
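import os
from keras.preprocessing.image import ImageDataGenerator

# data generation from directory without labels
datagen = ImageDataGenerator(rescale=1./255)
trn = datagen.flow_from_directory(os.path.join(BASE, 'train_gen'),
                                  batch_size=batch_size,
                                  target_size=tuple(inp_shape[:2]),
                                  class_mode=None)
X = next(trn)  # get one batch of data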
I want the data generator to start repeating data after it's exhausted.
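One way to get a stream that never runs short is to wrap the batch source and re-chunk it into batches of exactly batch_size, starting a new pass whenever the source ends. Below is a minimal NumPy sketch. `fixed_size_batches`, `make_iter` and `fake_pass` are hypothetical names. The sketch assumes each batch is a NumPy array with samples along the first axis, which is what flow_from_directory returns when class_mode=None:

```python
import numpy as np


def fixed_size_batches(make_iter, batch_size):
    """Yield NumPy batches of exactly `batch_size` samples, forever.

    `make_iter` is a zero-argument callable returning a fresh iterator over
    batches. Leftover samples from a short batch are carried into the next
    one, and a new iterator is created whenever the current one ends.
    Assumes the source yields at least one non-empty batch per pass.
    """
    source = make_iter()
    buffer = None  # samples not yet emitted
    while True:
        try:
            batch = next(source)
        except StopIteration:
            source = make_iter()  # start another pass over the data
            continue
        buffer = batch if buffer is None else np.concatenate([buffer, batch])
        # Emit every full batch the buffer can supply; keep the remainder.
        while len(buffer) >= batch_size:
            yield buffer[:batch_size]
            buffer = buffer[batch_size:]


def fake_pass():
    """Stand-in for one pass over 70 images with batch_size 32: 32, 32, 6."""
    data = np.arange(70)
    return iter([data[i:i + 32] for i in range(0, 70, 32)])


gen = fixed_size_batches(fake_pass, 32)
print([len(next(gen)) for _ in range(5)])  # [32, 32, 32, 32, 32]
```

With a directory iterator such as trn, it could be called as fixed_size_batches(lambda: trn, batch_size). The wrapper carries leftover samples forward instead of dropping the short batch, so every image stays in rotation. The trade-off is that one batch per pass mixes images from two consecutive passes.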
More specifically, I am trying to train a GAN. A batch of images is generated by the generator model, concatenated with a batch of real images, and then used to train the discriminator model and the combined GAN model. I can't figure out how to use fit_generator in this setup. The code is below:
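import os
import numpy as np
import scipy.ndimage as ndi
from keras.preprocessing.image import ImageDataGenerator

def train(self, inp_shape, batch_size=1, n_epochs=1000):
    BASE = '/content/gdrive/My Drive/Dataset/GAN'
    datagen = ImageDataGenerator(rescale=1./255)
    # distorted input images for the generator (no labels)
    trn_dist = datagen.flow_from_directory(os.path.join(BASE, 'train_gen'),
                                           batch_size=batch_size,
                                           target_size=tuple(inp_shape[:2]),
                                           seed=1360000,
                                           class_mode=None)
    val_dist = datagen.flow_from_directory(os.path.join(BASE, 'test_gen'),
                                           batch_size=batch_size,
                                           target_size=tuple(inp_shape[:2]),
                                           class_mode=None)
    # real images for the discriminator (no labels)
    trn_real = datagen.flow_from_directory(os.path.join(BASE, 'train_real'),
                                           batch_size=batch_size,
                                           target_size=tuple(inp_shape[:2]),
                                           seed=1360000,
                                           class_mode=None)
    for e in range(n_epochs):  # each iteration processes one batch
        real_images = next(trn_real)
        dist_images = next(trn_dist)
        gen_images = self.generator.predict(dist_images)
        # resize generator output (assumed 250 px) to the discriminator input size
        factor = inp_shape[0] / 250
        gen_res = ndi.zoom(gen_images, (1, factor, factor, 1), order=2)
        X = np.concatenate([real_images, gen_res])
        # real images -> 1, generated images -> 0; label counts come from the
        # arrays so a short final batch does not cause a shape mismatch
        y = np.concatenate([np.ones(len(real_images)), np.zeros(len(gen_res))])
        # note: Keras applies changes to `trainable` only when a model is compiled
        self.discriminator.trainable = True
        self.discriminator.train_on_batch(X, y)  # one gradient step on this batch
        self.discriminator.trainable = False
        # train the combined model, labelling the generated images as real
        self.model.train_on_batch(gen_res, np.ones(len(gen_res)))
        print('> training --- epoch=%d/%d' % (e, n_epochs))
        if e > 0 and e % 2000 == 0:
            self.model.save('%s/models/gan_model_%d_.h5' % (BASE, e))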
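The discriminator labels have to match the number of images actually returned, which is not always batch_size. For example, the third batch in the 70-image case holds only 6. For the alternating GAN updates, a hand-written loop that calls train_on_batch is a common alternative to fit_generator, because each step needs fresh generator output. The sketch below covers only the step that builds the labelled discriminator batch. `discriminator_batch` is a hypothetical helper, and the random 64x64x3 arrays are arbitrary stand-ins for image batches:

```python
import numpy as np


def discriminator_batch(real_images, fake_images):
    """Stack real and generated images; label real as 1 and generated as 0.

    Label counts come from the arrays, so a short final batch still
    gets exactly one label per image.
    """
    X = np.concatenate([real_images, fake_images])
    y = np.concatenate([np.ones(len(real_images)), np.zeros(len(fake_images))])
    return X, y


real = np.random.rand(6, 64, 64, 3)   # e.g. a short final batch of real images
fake = np.random.rand(32, 64, 64, 3)  # a full batch of generated images
X, y = discriminator_batch(real, fake)
print(X.shape, y.sum())  # (38, 64, 64, 3) 6.0
```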
PS: I am new to GANs, so please correct me if I am doing something wrong.