Following this great post, Scaling Keras Model Training to Multiple GPUs, I tried to upgrade my model to run in parallel on my multi-GPU instance.
First I ran the MNIST example proposed here, MNIST in Keras, with the additional syntax in the compile command as follows:
# Prepare the list of GPUs to be used in training
NUM_GPU = 8  # or the number of GPUs available on your machine
gpu_list = ['gpu(%d)' % i for i in range(NUM_GPU)]

# Compile the model, setting the context to the list of GPUs to be used in training
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'],
              context=gpu_list)
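(Side note: since context is an extension provided by the MXNet backend fork of Keras, the number of GPUs doesn't have to be hard-coded; a minimal sketch of querying it from MXNet, assuming mxnet is installed:)

import mxnet as mx

# list_gpus() returns the indices of the visible GPUs (it parses nvidia-smi),
# so this falls back to zero GPUs on a CPU-only machine
NUM_GPU = len(mx.test_utils.list_gpus())
gpu_list = ['gpu(%d)' % i for i in range(NUM_GPU)]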
Then I trained the model:
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
So far so good. It ran at less than 1s per epoch, and I was really excited and happy until I tried data augmentation.
Up to that point, my training images were a numpy array of shape (60000, 1, 28, 28) and the labels were one-hot encoded with shape (60000, 10).
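(For reference, this is roughly how that data would be prepared; a minimal sketch, assuming the standard keras.datasets.mnist loader and a channels-first image_data_format:)

from keras.datasets import mnist
from keras.utils import to_categorical

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Add a channel axis and scale pixels to [0, 1]
x_train = x_train.reshape(-1, 1, 28, 28).astype('float32') / 255
x_test = x_test.reshape(-1, 1, 28, 28).astype('float32') / 255
# One-hot encode the labels -> shapes (60000, 10) and (10000, 10)
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

For data augmentation I used the ImageDataGenerator: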
gen = image.ImageDataGenerator(rotation_range=8, width_shift_range=0.08,
                               shear_range=0.3, height_shift_range=0.08,
                               zoom_range=0.08)
batches = gen.flow(x_train, y_train, batch_size=NUM_GPU*64)
test_batches = gen.flow(x_test, y_test, batch_size=NUM_GPU*64)
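(As an aside, written this way test_batches applies the same random augmentation to the validation data; if that's not intended, a plain generator works for validation, e.g.:)

# A generator constructed with no arguments performs no augmentation
val_gen = image.ImageDataGenerator()
test_batches = val_gen.flow(x_test, y_test, batch_size=NUM_GPU*64)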
and then:
model.fit_generator(batches,
                    steps_per_epoch=batches.n // batches.batch_size,
                    epochs=1,
                    validation_data=test_batches,
                    validation_steps=test_batches.n // test_batches.batch_size)
And unfortunately, from ~1s per epoch I started getting ~11s per epoch... I suppose the ImageDataGenerator is the culprit: the whole (read -> augment -> write to GPU) pipeline is probably running serially and can't keep the GPUs fed.
Scaling Keras to multiple GPUs is great, but data augmentation is essential for my model to be robust enough.
I guess one solution could be to load all the images and write my own function that shuffles and augments them (something like the sketch below), but I'm sure there must be an easier way to optimize this process using the Keras API.
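(For concreteness, a minimal sketch of that idea, where augment_batch is a hypothetical placeholder for whatever per-batch augmentation gets applied:)

import numpy as np

def custom_flow(x, y, batch_size, augment_batch):
    """Shuffle once per epoch, then yield augmented batches forever."""
    n = len(x)
    while True:
        order = np.random.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield augment_batch(x[idx]), y[idx]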
Thanks!