Even as of Keras 1.2.2 (referencing this merge), multiprocessing is included, but model.fit_generator() is still about 4-5x slower than model.fit() due to disk-reading speed limitations. How can this be sped up, say through additional multiprocessing?

Marcin Możejko
mikal94305
- It depends where the bottleneck is... If it's reading-speed limitations, increase the batch_size to slow down your training step, and increase the queue size and number of workers. Are you training on GPU or CPU? – Nassim Ben Mar 07 '17 at 07:35
- It would also be great if you provided details about your data, batch size, kind of loading, etc. – Marcin Możejko Mar 07 '17 at 10:09
- Training is on GPU. I've changed the batch size from 32 to 64 to 128, and there are no significant differences in speed. – mikal94305 Mar 07 '17 at 22:04
- It is supposed to be slower by design. There is a lot of I/O-related overhead in `fit_generator` that is not present in `fit()`. An SSD may be the way to mitigate that. – Jun 30 '19 at 11:17
2 Answers
You may want to check out the `workers` and `max_queue_size` parameters of `fit_generator()` in the documentation. Essentially, more workers create more threads for loading data into the queue that feeds your network. Filling the queue might cause memory problems, though, so you might want to decrease `max_queue_size` to avoid that.
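As a rough sketch of that tuning (parameter names follow the Keras 2.x `fit_generator()` signature; in Keras 1.x the equivalents are roughly `max_q_size` and `nb_worker`, and `build_model()` / `my_generator()` here are hypothetical placeholders for your own model and generator):

```python
# A minimal sketch, not a drop-in solution: Keras 2.x parameter names,
# with build_model() and my_generator() standing in for your own code.
model = build_model()

model.fit_generator(
    my_generator(batch_size=128),  # your data-loading generator
    steps_per_epoch=1000,          # batches drawn from the generator per epoch
    epochs=10,
    workers=4,                     # more loader workers fill the queue faster
    max_queue_size=10,             # lower this if memory becomes a problem
    use_multiprocessing=True,      # processes instead of threads (generator must be picklable)
)
```

Raising `workers` only helps while reading and preprocessing batches is the bottleneck; once the queue stays full, the training step itself limits throughput.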

Mach_Zero

I had a similar problem and solved it by switching from a pandas-based generator to dask, loading the data into memory instead. So, depending on your data size, load the data into memory if possible and use the `fit()` function.
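A minimal sketch of that approach, assuming the training data lives in CSV shards; the file paths, column names, and `model` are hypothetical placeholders:

```python
# A minimal sketch, assuming CSV shards and a "label" column; paths, columns,
# and `model` are placeholders for your own setup.
import dask.dataframe as dd

ddf = dd.read_csv("data/train_*.csv")  # lazily scan the shards in parallel
df = ddf.compute()                     # materialize one in-memory pandas DataFrame

X = df.drop(columns=["label"]).values
y = df["label"].values

# With the data in memory, fit() avoids the per-batch disk I/O of fit_generator().
model.fit(X, y, batch_size=128, epochs=10)
```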

nafizh