
I'm trying to load several checkpoints and save the average of their weights using TF 2.1. I found a TF1 version of this: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py

The variable "checkpoints" is a list of checkpoint paths. This is what I have so far:

  # Read variables from all checkpoints and average them.
  logger.info("Reading variables and averaging checkpoints:")
  for c in checkpoints:
    logger.info(c)

  var_list = tf.train.list_variables(checkpoints[0])

  var_values, var_dtypes = {}, {}
  for (name, shape) in var_list:
    if not name.startswith("global_step"):
      var_values[name] = tf.zeros(shape)

  for checkpoint in checkpoints:
    reader = tf.train.load_checkpoint(checkpoint)
    for name in var_values:
      tensor = tf.convert_to_tensor(reader.get_tensor(name))

      if tensor.dtype == tf.string:
        var_values[name] = tensor
      else:
        var_values[name] = tf.cast(var_values[name], tensor.dtype)
        var_values[name] += tensor
      var_dtypes[name] = tensor.dtype
    logger.info("Read from checkpoint %s", checkpoint)

  for name in var_values:  # Average.
    if var_dtypes[name] != tf.string:
      var_values[name] /= len(checkpoints)

Could you explain how to save the averaged var_values to a checkpoint?

1 Answer


I was able to save the averaged checkpoint by following the Keras version of the same question, since TensorFlow 2.1 follows the Keras API.

URL: Average weights in keras models
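
For reference, here is a minimal sketch of that Keras-style approach, not a definitive implementation. It assumes you already have a built Keras model whose architecture matches the checkpoints, and that the checkpoints can be loaded with model.load_weights. The function names average_weights and save_averaged_checkpoint, and the parameter output_path, are hypothetical names chosen for illustration.

```python
import numpy as np


def average_weights(weight_sets):
    """Element-wise mean over several weight lists.

    Each entry of weight_sets is a list of arrays, such as the value
    returned by a Keras model's get_weights(). All lists must have the
    same length and matching shapes.
    """
    # zip(*weight_sets) groups the i-th array of every checkpoint together.
    return [np.mean(np.stack(arrays, axis=0), axis=0)
            for arrays in zip(*weight_sets)]


def save_averaged_checkpoint(model, checkpoints, output_path):
    """Load each checkpoint into model, average the weights, and save.

    Assumption: model is an already built Keras model (or any object
    exposing load_weights, get_weights, set_weights and save_weights).
    """
    weight_sets = []
    for path in checkpoints:
        model.load_weights(path)
        weight_sets.append(model.get_weights())

    model.set_weights(average_weights(weight_sets))
    model.save_weights(output_path)
```

Calling save_averaged_checkpoint(model, checkpoints, "averaged") would then write the averaged weights with the model's own save_weights, so they can later be restored with load_weights on the same architecture. Note that this averages every weight returned by get_weights(), so it does not skip anything the way the global_step check in the question does.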