When using powerful hardware, especially TPUs, it is often preferable to run multiple training steps per call to the device instead of one. In TensorFlow, for example, this can be done as follows:

with strategy.scope():
  model = create_model()
  optimizer_inner = AdamW(weight_decay=1e-6)
  optimizer_middle = SWA(optimizer_inner)
  optimizer = Lookahead(optimizer_middle)
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)

# Calculate per replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker.
actual_batch_size = 128
gradient_accumulation_step = 1
batch_size = actual_batch_size * gradient_accumulation_step
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size

train_dataset = get_dataset(batch_size, is_training=True)
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync

train_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))

@tf.function(jit_compile=True)
def train_multiple_steps(iterator, steps):
  """Runs `steps` training steps inside a single `tf.function` call."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  for _ in tf.range(steps):
    strategy.run(step_fn, args=(next(iterator),))

train_iterator = iter(train_dataset)
for epoch in range(10):
  print('Epoch: {}/10'.format(epoch + 1))

  # Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get
  # retraced if the value changes.
  train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))

In JAX or Flax, however, I haven't seen a complete working example of this. I guess it would look something like:

@jax.jit
def train_for_n_steps(train_state, batches):
    for batch in batches:
        train_state = train_step_fn(train_state, batch)
    return train_state

However, when I try this on a complete example, I am not sure how to create multiple batches to pass in. Here is a working example on GPU that trains one step at a time. The relevant code is probably this:

for step,batch in enumerate(train_ds.as_numpy_iterator()):

  # Run optimization steps over training batches and compute batch metrics
  state = train_step(state, batch) # get updated train state (which contains the updated parameters)
  state = compute_metrics(state=state, batch=batch) # aggregate batch metrics

  if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed
    for metric,value in state.metrics.compute().items(): # compute metrics
      metrics_history[f'train_{metric}'].append(value) # record metrics
    state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch

    # Compute metrics on the test set after each training epoch
    test_state = state
    for test_batch in test_ds.as_numpy_iterator():
      test_state = compute_metrics(state=test_state, batch=test_batch)

    for metric,value in test_state.metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)

    print(f"train epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['train_loss'][-1]}, "
          f"accuracy: {metrics_history['train_accuracy'][-1] * 100}")
    print(f"test epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['test_loss'][-1]}, "
          f"accuracy: {metrics_history['test_accuracy'][-1] * 100}")

My goal is to run 5 training steps per jitted call, i.e. to unroll the training loop 5 times.

Any suggestions are welcome.

RanWang
  • I think JIT-ing the loop like you've said would be the way to do it. It'll make compilation take way longer and I don't know if it'll actually improve performance. – Davis Yoshida May 09 '23 at 21:04
  • Hi, I am not sure what is the exact way of doing it. How did you make it runnable? I just encountered a syntax error.... – RanWang May 10 '23 at 08:18
  • What did the error say? – Davis Yoshida May 10 '23 at 17:22
  • If you check the functions here, then there is no such thing as loop unrolling in the original jit code. https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit. – RanWang May 15 '23 at 23:30
  • The JIT works by tracing the execution of a python function. Because your loop runs a static number of times, the loop will effectively be unrolled. – Davis Yoshida May 15 '23 at 23:33
  • I am sorry but I am not sure what you mean by that. What I face is that I have a TPU which is very powerful and if I feed the loop only once, then its computational power is not fully realized. As is shown in the TensorFlow example, it is possible to send multiple batches to the TPU. – RanWang May 15 '23 at 23:56
  • Okay I'll just make a full answer – Davis Yoshida May 16 '23 at 00:29

1 Answer

You could use more_itertools.chunked to get something like this:

from more_itertools import chunked

for step, five_batches in enumerate(chunked(train_ds.as_numpy_iterator(), 5)):
    state = five_steps(state, five_batches)

Then define five_steps, which does the unrolling:

@jax.jit
def five_steps(state, batches):
    for batch in batches:
        state = train_step(state, batch)
    return state

The reason this works is that batches has a length that isn't data dependent, so the loop will just get executed 5 times during tracing.
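To make the tracing behaviour concrete, here is a minimal sketch. The `train_step` below is a toy stand-in (a scalar update, not the Flax step from the question), and `trace_calls` is a hypothetical counter added only to show when the Python loop body actually runs.

```python
import jax
import jax.numpy as jnp

trace_calls = []  # hypothetical counter: grows only when Python code runs


def train_step(state, batch):
    # Toy stand-in for a real training step (assumption, not the Flax step).
    trace_calls.append(1)  # Python side effect, executed only during tracing
    return state - 0.1 * jnp.mean(batch)


@jax.jit
def five_steps(state, batches):
    # `batches` is a Python list, so this loop runs at trace time and is
    # unrolled into len(batches) copies of train_step in the compiled code.
    for batch in batches:
        state = train_step(state, batch)
    return state


batches = [jnp.full((4,), float(i)) for i in range(5)]
state0 = jnp.float32(1.0)

state = five_steps(state0, batches)  # first call: traces and compiles
calls_after_first = len(trace_calls)

state = five_steps(state0, batches)  # second call: reuses the compiled code
calls_after_second = len(trace_calls)

print(calls_after_first, calls_after_second)
```

The counter stops growing after the first call because the second call reuses the compiled, already unrolled computation. Passing a list of a different length (for example a shorter final chunk) would trigger a new trace.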

This will likely make jitting take much longer than you want, so the preferred but more difficult way is to pack the batches into [N x batch_size x ...] tensors, then use jax.lax.scan to loop your update function over the inputs.
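A rough sketch of the scan-based variant follows. It assumes each batch is a dict of NumPy arrays with identical shapes; `train_step` is a toy least-squares update standing in for the real one, and `stack_batches` and `train_n_steps` are hypothetical helper names, not part of JAX or Flax.

```python
import jax
import jax.numpy as jnp
import numpy as np


def train_step(params, batch):
    # Toy stand-in (assumption): one SGD step on a least-squares loss.
    x, y = batch["x"], batch["y"]

    def loss_fn(p):
        return jnp.mean((x @ p - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # scan expects (new_carry, per_step_output).
    return params - 0.1 * grads, loss


def stack_batches(batches):
    # Hypothetical helper: list of N dict batches -> one dict whose leaves
    # have shape [N, batch_size, ...].
    return jax.tree_util.tree_map(lambda *xs: np.stack(xs), *batches)


@jax.jit
def train_n_steps(params, stacked):
    # lax.scan slices the leading axis of every leaf in `stacked`, so
    # train_step is traced once no matter how large N is.
    return jax.lax.scan(train_step, params, stacked)


rng = np.random.default_rng(0)
batches = [
    {
        "x": rng.normal(size=(8, 3)).astype(np.float32),
        "y": rng.normal(size=(8,)).astype(np.float32),
    }
    for _ in range(5)
]

stacked = stack_batches(batches)
params, losses = train_n_steps(jnp.zeros(3), stacked)
print(losses.shape)
```

With a real Flax train state the carry would be the state object instead of a bare parameter vector, which should work as long as the state is a pytree and the step returns a state of the same structure. A stacked input with a different N still causes a recompile, so a shorter final chunk is best dropped or handled separately.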

Davis Yoshida
  • While you were writing, I actually tried this one. However, there is some weird behavior in the code. Most importantly, it seems that the metrics are not computed. Here is the Colab notebook https://colab.research.google.com/gist/rwbfd/0a4111863f2d5ad1ddca654721ff8759/geek_bang_chapter_13_part2.ipynb – RanWang May 16 '23 at 01:04
  • I think you should divide `num_steps_per_epoch` by 5 since each update of `step` is now 5 steps. You could also do `step += 5`, but then your `if (step+1) % num_steps_per_epoch == 0:` check might be skipped – Davis Yoshida May 16 '23 at 01:08
  • Done. This is the complete notebook. https://colab.research.google.com/gist/rwbfd/cee859ae5ac6bd1cfa49fca250423f7f/geek_bang_chapter_13_part2.ipynb. – RanWang May 16 '23 at 01:44
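The epoch bookkeeping suggested in the comments can be sketched as below. The local `chunked` is a small stand-in for `more_itertools.chunked` so the snippet has no third-party dependency, and `range(40)` plays the role of the batch stream; `steps_per_call`, `calls_per_epoch`, and `epoch_ends` are hypothetical names.

```python
from itertools import islice


def chunked(iterable, n):
    # Minimal stand-in for more_itertools.chunked (assumption for this sketch).
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk


steps_per_call = 5
num_steps_per_epoch = 20
# Each loop iteration now covers 5 training steps, so count calls, not steps.
calls_per_epoch = num_steps_per_epoch // steps_per_call

epoch_ends = []  # training-step counts at which an epoch boundary is detected
for call, five_batches in enumerate(chunked(range(40), steps_per_call)):
    # state = five_steps(state, five_batches) would go here.
    if (call + 1) % calls_per_epoch == 0:
        epoch_ends.append((call + 1) * steps_per_call)

print(epoch_ends)
```

This only lines up cleanly when `num_steps_per_epoch` is a multiple of `steps_per_call`; otherwise the epoch boundary falls inside a chunk.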