
I used the following code to create a checkpoint manager outside of the loop in which I train my model:

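```python
import tensorflow as tf

checkpoint_path = "./checkpoints/train"
# object_1 is whatever needs to be checkpointed (model, optimizer, ...).
ckpt = tf.train.Checkpoint(object_1=object_1)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)
```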

Then, while training the model, I call `ckpt_save_path = ckpt_manager.save()` after each epoch to save the variables.
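A minimal sketch of that saving loop, assuming eager TensorFlow 2.x. `object_1` is replaced by a single hypothetical `tf.Variable`, the training step is a placeholder, and the directory is temporary. Two deliberate differences from the code above: `max_to_keep=EPOCHS` keeps every epoch's checkpoint, and `checkpoint_number=epoch` makes each file name carry its epoch number explicitly instead of relying on the internal save counter.

```python
import os
import tempfile

import tensorflow as tf

EPOCHS = 3  # hypothetical number of training epochs

# Hypothetical stand-in for object_1 (a model, optimizer, etc. in real code).
object_1 = tf.Variable(0.0)

checkpoint_path = os.path.join(tempfile.mkdtemp(), "train")
ckpt = tf.train.Checkpoint(object_1=object_1)
# Keep one checkpoint per epoch instead of only the most recent one.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=EPOCHS)

saved_paths = []
for epoch in range(1, EPOCHS + 1):
    object_1.assign_add(1.0)  # placeholder for one epoch of training
    # The returned prefix ends in "ckpt-<epoch>", e.g. ".../train/ckpt-2".
    ckpt_save_path = ckpt_manager.save(checkpoint_number=epoch)
    saved_paths.append(ckpt_save_path)

print(saved_paths)
```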

Because I want to implement early stopping, I need to restore all the variables as they were after a specific epoch and use them to make predictions. Given the code above (assuming the saving process is correct), how can I restore the variables saved after, say, epoch e? I know I can recreate the same variables and objects and then restore the latest checkpoint with the following code, but I don't know how to restore a specific checkpoint (such as the one saved after epoch e) rather than the latest one:

```python
ckpt.restore(ckpt_manager.latest_checkpoint).assert_consumed()
```

Thanks,

khemedi



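Yes, you need to build the checkpoint file name from the epoch number. `CheckpointManager.save()` writes each checkpoint under the prefix `<directory>/ckpt-<N>`, where N is the save count (or the `checkpoint_number` argument, if one is passed), so when you save once per epoch starting from scratch, N is the epoch number.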
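The file-name logic can be sketched as two small helpers. The function names are hypothetical; the default `name="ckpt"` assumes the manager's default `checkpoint_name`. Using `os.path.join` avoids the missing-separator problem that plain string concatenation has when the directory lacks a trailing slash.

```python
import os


def checkpoint_path_for_epoch(checkpoint_dir: str, epoch: int, name: str = "ckpt") -> str:
    """Build the checkpoint prefix for a given epoch, e.g. <dir>/ckpt-7."""
    return os.path.join(checkpoint_dir, f"{name}-{epoch}")


def epoch_from_checkpoint_path(path: str, name: str = "ckpt") -> int:
    """Recover the number after the last '<name>-' in a checkpoint prefix."""
    return int(path.rsplit(f"{name}-", 1)[-1])


path = checkpoint_path_for_epoch("./checkpoints/train", 7)
print(path, epoch_from_checkpoint_path(path))
```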

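```python
import os

import tensorflow as tf

# `checkpoint` is the tf.train.Checkpoint to restore into; CHECKPOINT_SAVE_DIR is
# the directory the manager saves to. Remaining arguments are elided.
c_manager = tf.train.CheckpointManager(checkpoint, ...)

if EPOCH == '':
    # No epoch requested: fall back to the most recent checkpoint, if any.
    if c_manager.latest_checkpoint:
        tf.print("-----------Restoring from {}-----------".format(
            c_manager.latest_checkpoint))
        checkpoint.restore(c_manager.latest_checkpoint)
        # Prefixes look like ".../ckpt-<N>"; recover N as an int.
        EPOCH = int(c_manager.latest_checkpoint.split(sep='ckpt-')[-1])
    else:
        tf.print("-----------Initializing from scratch-----------")
else:
    # Restore the checkpoint saved as number EPOCH, i.e. ".../ckpt-<EPOCH>".
    checkpoint_fname = os.path.join(CHECKPOINT_SAVE_DIR, 'ckpt-' + str(EPOCH))
    tf.print("-----------Restoring from {}-----------".format(checkpoint_fname))
    checkpoint.restore(checkpoint_fname)
```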
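A self-contained sketch of restoring a specific epoch instead of the latest one, assuming eager TensorFlow 2.x. The variable stands in for `object_1`, and `EPOCHS` and `e` are hypothetical. A fresh `CheckpointManager` pointed at an existing directory recovers the list of kept checkpoints from the directory's `checkpoint` state file; this only works if `max_to_keep` was large enough during training that epoch `e` was not deleted.

```python
import os
import tempfile

import tensorflow as tf

EPOCHS = 5  # hypothetical total number of epochs
e = 3       # hypothetical epoch to restore

checkpoint_path = os.path.join(tempfile.mkdtemp(), "train")

# Training run: save after every epoch and keep all checkpoints.
object_1 = tf.Variable(0.0)  # hypothetical stand-in for object_1
ckpt = tf.train.Checkpoint(object_1=object_1)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=EPOCHS)
for epoch in range(1, EPOCHS + 1):
    object_1.assign(float(epoch))  # placeholder: value records which epoch wrote it
    ckpt_manager.save(checkpoint_number=epoch)

# Later: recreate the same objects and restore epoch e rather than the latest.
restored_object_1 = tf.Variable(0.0)
restore_ckpt = tf.train.Checkpoint(object_1=restored_object_1)
restore_manager = tf.train.CheckpointManager(restore_ckpt, checkpoint_path, max_to_keep=EPOCHS)

# Map "ckpt-<N>" -> full prefix for every checkpoint the manager still tracks.
kept = {os.path.basename(p): p for p in restore_manager.checkpoints}
restore_ckpt.restore(kept[f"ckpt-{e}"]).assert_existing_objects_matched()
print(restored_object_1.numpy())
```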
EyesBear
Thanks! I think another thing worth adding is that I should use `ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=EPOCHS)` instead of `ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)` when creating the ckpt_manager for saving the model during training, because with `max_to_keep=1` every older checkpoint is deleted and only the latest survives. Note: EPOCHS is the total number of training epochs, and EPOCH is the epoch after which the validation loss is lowest. – khemedi Jul 17 '20 at 13:36
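The comment's approach (keep all `EPOCHS` checkpoints, then restore the one with the lowest validation loss) can be sketched as below, assuming eager TensorFlow 2.x. The variable stands in for the model and the validation loss is a hypothetical function of it, chosen only so the loop has something to minimize; in real code it would come from evaluating on a validation set.

```python
import os
import tempfile

import tensorflow as tf

EPOCHS = 6  # hypothetical total number of epochs

checkpoint_path = os.path.join(tempfile.mkdtemp(), "train")
object_1 = tf.Variable(0.0)  # hypothetical stand-in for the model's state
ckpt = tf.train.Checkpoint(object_1=object_1)
# Keep every epoch's checkpoint so any of them can be restored afterwards.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=EPOCHS)

val_losses = {}
for epoch in range(1, EPOCHS + 1):
    object_1.assign_add(1.0)  # placeholder for one epoch of training
    # Hypothetical validation loss, smallest when object_1 == 4.
    val_losses[epoch] = float((object_1.numpy() - 4.0) ** 2)
    ckpt_manager.save(checkpoint_number=epoch)

# EPOCH: the epoch after which the validation loss was lowest.
EPOCH = min(val_losses, key=val_losses.get)

# Roll the variables back to their state after that epoch.
ckpt.restore(os.path.join(checkpoint_path, f"ckpt-{EPOCH}"))
print(EPOCH, object_1.numpy())
```

If only the best checkpoint is ever needed, a common alternative is to call `save()` only when the validation loss improves; `max_to_keep=1` then always holds the best epoch.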