
I am looking for the best approach to train on larger-than-memory data in Keras, and I am currently noticing that the vanilla ImageDataGenerator tends to be slower than I would hope.

I have two networks training on the Kaggle cats vs. dogs dataset (25,000 images):

1) This approach uses exactly the code from: http://www.pyimagesearch.com/2016/09/26/a-simple-neural-network-with-python-and-keras/

2) Same as (1), but using an ImageDataGenerator instead of loading the data into memory

Note: for below, "preprocessing" means resizing, scaling, and flattening.

I find the following on my GTX 970:

For network 1, it takes ~0s per epoch.

For network 2, it takes ~36s per epoch if the preprocessing is done in the data generator.

For network 2, it takes ~13s per epoch if preprocessing is done in a first-pass outside of the data generator.

Is this likely the speed limit for ImageDataGenerator (13s seems like the usual 10-100x difference between disk and RAM...)? Are there approaches/mechanisms better suited for training on larger-than-memory data when using Keras? For example, is there a way to get the ImageDataGenerator in Keras to save its processed images after the first epoch?
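For reference, the "first pass" mentioned above amounts to something like the following sketch (paths and the image size are placeholders):

import numpy as np
import cv2
from pathlib import Path

# One-off pass: resize, scale and flatten each image once and cache the result,
# so that later epochs only read the small cached arrays instead of re-decoding JPEGs.
src, dst = Path('train'), Path('preprocessed')
dst.mkdir(exist_ok=True)
for img_path in src.glob('*.jpg'):
    image = cv2.resize(cv2.imread(str(img_path)), (32, 32))
    np.save(dst / (img_path.stem + '.npy'), image.flatten() / 255.0)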

Thanks!

John Cast
  • While a little old now, this post is relevant: [Slow image data generator](https://github.com/keras-team/keras/issues/2394). The post suggests Keras (at least at some point in the past) applied several sequential transformations when a single transformation could have been used. – user3731622 Nov 22 '18 at 02:09
  • See this: https://github.com/stratospark/keras-multiprocess-image-data-generator/blob/master/Accelerating%20Deep%20Learning%20with%20Multiprocess%20Image%20Augmentation%20in%20Keras.md – Amir Saniyan Oct 28 '19 at 19:31

2 Answers


I assume you might already have solved this, but nevertheless...

Keras image preprocessing has the option of saving the results by setting the save_to_dir argument in the flow() or flow_from_directory() functions:

https://keras.io/preprocessing/image/
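For illustration, a minimal sketch of that option (paths and sizes are placeholders; note that the target directory must already exist):

from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1./255)

# Every batch the generator yields is also written to 'augmented/'
# as image files, so the processed images can be inspected or reused.
train_gen = datagen.flow_from_directory(
    'data/train',              # placeholder path
    target_size=(64, 64),
    batch_size=32,
    class_mode='binary',
    save_to_dir='augmented',   # directory must already exist
    save_prefix='aug',
    save_format='jpeg')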

Christian Gollhardt
petezurich

In my understanding, the problem is that augmented images are used only once in a training cycle of a model, and are not even reused across several epochs. So it's a huge waste of GPU cycles while the CPU is struggling. I found the following solution:

  1. I generate as many augmentations in RAM as I can,
  2. I use them for training across a frame of epochs (10 to 30, whatever it takes to get noticeable convergence),
  3. after that, I generate a new batch of augmented images (by implementing on_epoch_end) and the process goes on.

This approach keeps the GPU busy most of the time, while still benefiting from data augmentation. I use a custom Sequence subclass to generate the augmentations and fix class imbalance at the same time.

EDIT: adding some code to clarify the idea

from pyutilz.string import read_config_file
from tqdm.notebook import tqdm
from pathlib import Path
from gc import collect
import numpy as np
import tensorflow
import logging
import random
import json
import cv2

class StoppingFromFile(tensorflow.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # read_config_file injects the 'stop' setting into globals()
        if read_config_file('control.ini', 'ML', 'stop', globals()):
            if stop is not None:
                if stop == True or stop == 'True':
                    logging.warning('Model should be stopped according to the control file')
                    self.model.stop_training = True

class AugmentedBalancedSequence(tensorflow.keras.utils.Sequence):
    def __init__(self, images_and_classes: dict, input_size: tuple, class_sizes: dict,
                 augmentations_fn: object, preprocessing_fn: object, batch_size: int = 10,
                 num_class_samples=100, frame_length: int = 5, aug_p: float = 0.1,
                 aug_pipe_p: float = 0.2, is_validation: bool = False,
                 disk_saving_prob: float = .01, disk_example_nfiles: int = 50):
        """
            From a dict of file paths grouped by class label, creates each N epochs augmented balanced training set.
            If current class is too scarce, ensures that current frame has no duplicate final images.
            If it's rich enough, ensures that current frame has no duplicate base images.
        
        """
        logging.info(f'Got {len(images_and_classes)} classes.')
        self.disk_example_nfiles = disk_example_nfiles
        self.disk_saving_prob = disk_saving_prob
        self.cur_example_file = 0
        
        self.images_and_classes = images_and_classes
        self.num_class_samples = num_class_samples
        self.augmentations_fn = augmentations_fn
        self.preprocessing_fn = preprocessing_fn

        self.is_validation = is_validation
        self.frame_length = frame_length
        self.batch_size = batch_size
        self.class_sizes = class_sizes
        self.input_size = input_size
        self.aug_pipe_p = aug_pipe_p
        self.aug_p = aug_p
        self.images = None
        self.epoch = 0
        self._generate_data()
        

    def __len__(self):
        return int(np.ceil(len(self.images) / float(self.batch_size)))

    def __getitem__(self, idx):
        a = idx * self.batch_size
        b = a + self.batch_size
        return self.images[a:b], self.labels[a:b]
    
    def on_epoch_end(self):
        # allows hot-tuning of the sequence's attributes via an external json file
        self.epoch += 1

        fname = 'control.json'
        if Path(fname).is_file():
            try:
                with open(fname) as f:
                    mydict = json.load(f)
                for var, val in mydict.items():
                    if hasattr(self, var) and val is not None:
                        if getattr(self, var) != val:
                            setattr(self, var, val)
                            print(f'{var} became {val}')
            except Exception as e:
                logging.error(str(e))

        # regenerate the augmented set once per frame of epochs
        if self.epoch % self.frame_length == 0:
            self._generate_data()
            
    def _add_sample(self, image, label):
        idx = self.indices[self.img_sent]

        # occasionally save an augmented example to disk for visual inspection
        if self.disk_saving_prob > 0:
            if random.random() < self.disk_saving_prob:
                self.cur_example_file += 1
                if self.cur_example_file > self.disk_example_nfiles:
                    self.cur_example_file = 1
                Path('example_images/').mkdir(parents=True, exist_ok=True)
                cv2.imwrite(f'example_images/test{self.cur_example_file}.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

        if self.preprocessing_fn:
            self.images[idx] = self.preprocessing_fn(image)
        else:
            self.images[idx] = image

        self.labels[idx] = label
        self.img_sent += 1
        
    def _generate_data(self):
        logging.info('Generating new set of augmented data...')

        collect()

        if self.num_class_samples:
            expected_length = len(self.images_and_classes) * self.num_class_samples
        else:
            expected_length = sum(self.class_sizes.values())

        # pre-allocate the arrays once and reuse them across frames
        if self.images is None:
            self.images = np.empty((expected_length, self.input_size[1], self.input_size[0], 3))
            self.labels = np.empty(expected_length, np.int32)

        # random write order, so that consecutive batches are already shuffled
        self.indices = np.random.choice(expected_length, expected_length, replace=False)
        self.img_sent = 0

        relaxed_augmentation_pipeline = self.augmentations_fn(p=self.aug_p, pipe_p=self.aug_pipe_p)
        maxed_out_augmentation_pipeline = self.augmentations_fn(p=self.aug_p, pipe_p=1.0)

        nartificial = 0
        # for each class
        for label, images in tqdm(self.images_and_classes.items()):
            if self.num_class_samples is None:
                # just all native samples, without augmentations
                for image in images:
                    self._add_sample(image, label)
            else:
                if len(images) >= self.num_class_samples:
                    # enough native samples: randomly select the samples of this class
                    # which will participate in this frame of epochs
                    indices = np.random.choice(len(images), self.num_class_samples, replace=False)
                    # apply the albumentations pipeline to the selected samples
                    for idx in indices:
                        if not self.is_validation:
                            self._add_sample(relaxed_augmentation_pipeline(image=images[idx])['image'], label)
                        else:
                            self._add_sample(images[idx], label)
                else:
                    # scarce class: randomly pick the next image from the existing ones and keep applying
                    # the augmentation pipeline (with maxed-out probability) until we get
                    # num_class_samples DIFFERENT images
                    hashes = set()
                    norig = 0
                    while len(hashes) < self.num_class_samples:
                        if self.is_validation and norig < len(images):
                            # just include all the originals first
                            image = images[norig]
                        else:
                            image = maxed_out_augmentation_pipeline(image=random.choice(images))['image']
                        # a cheap content hash, used to skip duplicates
                        next_hash = np.sum(image)
                        if next_hash not in hashes or (self.is_validation and norig <= len(images)):
                            self._add_sample(image, label)
                            if next_hash in hashes:
                                norig += 1
                                hashes.add(norig)
                            else:
                                hashes.add(next_hash)
                                nartificial += 1

        logging.info(f'Generated {self.img_sent} samples ({nartificial} artificial)')

Once I have the images and classes loaded,

train_datagen = AugmentedBalancedSequence(
    images_and_classes=images_and_classes_train,
    input_size=INPUT_SIZE, class_sizes=class_sizes_train, num_class_samples=UPSCALE_SAMPLES,
    augmentations_fn=get_albumentations_pipeline, aug_p=AUG_P, aug_pipe_p=AUG_PIPE_P,
    preprocessing_fn=preprocess_input, batch_size=BATCH_SIZE, frame_length=FRAME_LENGTH,
    disk_saving_prob=0.05)

val_datagen = AugmentedBalancedSequence(
    images_and_classes=images_and_classes_val,
    input_size=INPUT_SIZE, class_sizes=class_sizes_val, num_class_samples=None,
    augmentations_fn=get_albumentations_pipeline, preprocessing_fn=preprocess_input,
    batch_size=BATCH_SIZE, frame_length=FRAME_LENGTH, is_validation=True)

and after the model is instantiated, I do

model.fit(train_datagen,epochs=600,verbose=1,
          validation_data=(val_datagen.images,val_datagen.labels),validation_batch_size=BATCH_SIZE,
          callbacks=[checkpointer,StoppingFromFile()],validation_freq=1)
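Note that get_albumentations_pipeline is not shown above; it is just a factory for an augmentation pipeline. A minimal sketch of what it might look like, assuming the albumentations library (the transform choices are illustrative):

import albumentations as A

def get_albumentations_pipeline(p: float = 0.1, pipe_p: float = 0.2):
    # p is the probability of each individual transform;
    # pipe_p is the probability of applying the pipeline at all
    return A.Compose([
        A.HorizontalFlip(p=p),
        A.RandomBrightnessContrast(p=p),
        A.ShiftScaleRotate(p=p),
    ], p=pipe_p)

The resulting object is called as pipeline(image=image)['image'], which matches how the Sequence above uses it.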
Anatoly Alekseev