
I have a large dataset that fits in host memory. However, when I use tf.keras to train the model, it raises a GPU out-of-memory error. I then looked into tf.data.Dataset and want to use its batch() method to batch the training dataset so that model.fit() can run on the GPU. According to its documentation, an example is as follows:

train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

Is the BATCH_SIZE in dataset.from_tensor_slices().batch() the same as the batch_size in tf.keras model.fit()?

How should I choose BATCH_SIZE so that the GPU has enough data to run efficiently, yet its memory does not overflow?


1 Answer


You do not need to pass the batch_size parameter to model.fit() in this case. It will automatically use the BATCH_SIZE that you set in tf.data.Dataset().batch().

As for your other question: the batch size is indeed a hyperparameter that needs careful tuning. If you see OOM errors, you should decrease the batch size until the errors disappear (typically, though not necessarily, by halving: 32 --> 16 --> 8 ...). You can also try batch sizes that are not powers of two when decreasing.

Alternatively, in your case I would start with a batch_size of 2 and increase it gradually (3, 4, 5, 6, ...).
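The halving strategy described above can be sketched as a simple search loop. This is a hypothetical helper, not TensorFlow API: the try_fit callback is an assumption standing in for "attempt one training step at this batch size", and in real TensorFlow code the failure would surface as tf.errors.ResourceExhaustedError rather than MemoryError.

```python
# Hypothetical sketch of the "halve until it fits" strategy.
# try_fit is an assumed callback that attempts one training step at the given
# batch size and raises MemoryError when the batch does not fit in GPU memory
# (with TensorFlow you would catch tf.errors.ResourceExhaustedError instead).
def largest_fitting_batch_size(try_fit, start=32, minimum=1):
    batch_size = start
    while batch_size >= minimum:
        try:
            try_fit(batch_size)   # one trial step at this batch size
            return batch_size     # it fit: use this batch size
        except MemoryError:
            batch_size //= 2      # OOM: halve and retry (32 -> 16 -> 8 ...)
    return None                   # nothing fit, even at the minimum


# Simulated usage: pretend anything above 8 samples overflows GPU memory.
def fake_try_fit(batch_size):
    if batch_size > 8:
        raise MemoryError("simulated GPU OOM")

print(largest_fitting_batch_size(fake_try_fit, start=32))  # prints 8
```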

You do not need to provide the batch_size parameter if you use the tf.data.Dataset().batch() method.

In fact, even the official documentation states this:

batch_size : Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).
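As a minimal sketch of the behaviour the quote describes (the toy data and model here are made up purely for illustration), the dataset is batched once with batch(), and model.fit() is then called without any batch_size argument:

```python
import numpy as np
import tensorflow as tf

# Toy stand-ins for train_examples / train_labels (illustrative only).
train_examples = np.random.rand(100, 4).astype("float32")
train_labels = np.random.randint(0, 2, size=(100,))

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)

# A tiny model, just to show the fit() call.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# No batch_size argument: the dataset already yields 64-sample batches.
history = model.fit(train_dataset, epochs=1, verbose=0)
```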

  • Thanks. But with Dataset and batch(), I don't have GPU OOM problem anymore. I still use batch_size=32. Any comments? – David293836 Jul 01 '20 at 06:30
  • I am updating the comment up so that you can see more details – Timbus Calin Jul 01 '20 at 07:05
  • 1
    Got it. Thanks for sharing the info from the official documentation. It has cleared some of my confusions! However, one thing worth noting is that instead of changing batch_size from 32->16->8, I change from 32->64->128->264 with the Dataset batch(). I don't run into OOM anymore. With larger batch_size values, the execution time is greatly reduced and the performance metrics are the same. – David293836 Jul 01 '20 at 18:43
  • The metrics with larger batch sizes are... debatable. A lot of papers and discussions on that topic. – Timbus Calin Jul 01 '20 at 19:50
  • When you have 2M rows, it is doubtful that batch_size = 32 or 64 would make any difference. This is from my experience, FWIW. – David293836 Jul 01 '20 at 20:35
  • Depending on what type of data you have, the complexity of the network etc etc – Timbus Calin Jul 02 '20 at 04:16