I am fairly new to the TensorFlow world but have written some programs in Keras. Since TensorFlow 2 officially adopted Keras as its high-level API, I am quite confused about the difference between tf.keras.callbacks.ModelCheckpoint and tf.train.Checkpoint. If anybody can shed light on this, I would appreciate it.
-
Isn't it a bit obvious? If you are using Keras, which is a high-level layer running over TensorFlow, you would save checkpoints using the Keras callback. I am fairly sure the Keras callback actually instructs the TensorFlow checkpoint because, as said before, Keras is just a layer over TF... – neel g Apr 16 '20 at 13:40
-
@neelg It's far from obvious how to perform a perfect state save and reload, including all the optimizer state, using the ModelCheckpoint callback. I only managed to achieve this with tf.train.Checkpoint. With ModelCheckpoint one seems to be forced to save the entire model, including the graph info, in order to save the optimizer properly, but this takes very long every time a checkpoint is saved. – isarandi Sep 06 '21 at 15:00
3 Answers
It depends on whether a custom training loop is required. In most cases it's not, and you can just call model.fit() and pass tf.keras.callbacks.ModelCheckpoint. If you do need to write your own custom training loop, then you have to use tf.train.Checkpoint (and tf.train.CheckpointManager), since there's no callback mechanism.
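To make the custom-loop case concrete, here is a minimal self-contained sketch; the toy model, optimizer, data, and temporary checkpoint directory are all invented for illustration:

```python
import tempfile

import tensorflow as tf

# Toy model, optimizer, and data, purely for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
x = tf.random.normal((8, 4))
y = tf.random.normal((8, 1))

# Track the optimizer too, so its internal state (e.g. Adam moments)
# is saved and restored together with the model weights.
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, tempfile.mkdtemp(), max_to_keep=3)

# Restore the latest checkpoint if one exists (a no-op on the first run).
ckpt.restore(manager.latest_checkpoint)

for epoch in range(2):  # the custom training loop
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    save_path = manager.save()  # write a checkpoint at the end of each epoch
```

Because the optimizer is listed in the Checkpoint alongside the model, resuming from `manager.latest_checkpoint` picks up training where it left off instead of resetting the optimizer state.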

-
That is not true. Even when we build a custom training loop, we may use Keras callbacks. Refer to the following page: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback – chanwcom Oct 10 '21 at 10:19
TensorFlow is a 'computation' library and Keras is a deep learning library which can work with TF, PyTorch, etc. So what TF provides is a more generic, not-so-customized-for-deep-learning version. If you just compare the docs you can see how much more comprehensive and customized ModelCheckpoint is. Checkpoint just reads and writes stuff from/to disk; ModelCheckpoint is much smarter!
Also, ModelCheckpoint is a callback, which means you can just make an instance of it and pass it to the fit function:
model_checkpoint = ModelCheckpoint(...)
model.fit(..., callbacks=[..., model_checkpoint, ...], ...)
I took a quick look at Keras's implementation of ModelCheckpoint; it calls either the save or save_weights method on Model, which in some cases uses TensorFlow's Checkpoint itself. So it is not a wrapper per se, but it certainly is on a lower level of abstraction -- more specialized for saving Keras models.

-
So, TensorFlow's `Checkpoint` is more customizable and Keras's `ModelCheckpoint` is a high-level layer on top of it? – superduper Apr 17 '20 at 00:20
-
Well, it has fewer features. Depends on how you want to customize it. – Mohammad Jafar Mashhadi Apr 17 '20 at 04:21
-
Because it is not specialized for saving models, it is more abstract and has fewer features – Mohammad Jafar Mashhadi Apr 17 '20 at 04:25
-
I think that TensorFlow's `Checkpoint` is the way of doing it in TensorFlow only; that's why it needs more work and you manually have to add the functionality, whereas Keras is more high-level and automatically does a lot of work for you. – superduper Apr 17 '20 at 04:52
-
That's right. TF is for graph computations in general, so it has the bare minimum; it's by design. – Mohammad Jafar Mashhadi Apr 17 '20 at 05:33
I also had a hard time differentiating between the checkpoint objects when I looked at other people's code, so I wrote down some notes about when to use which one and how to use them in general. Either way, I think it might be useful for other people having the same issue:
Saving model Checkpoints
These are 2 ways of saving your model's checkpoints, each for a different use case:
1) Checkpoint & CheckpointManager
This is useful when you are managing the training loops yourself.
You use them like this:
1.1) Checkpoint
Definition from the docs: "A Checkpoint object can be constructed to save either a single or group of trackable objects to a checkpoint file".
How to initialise it:
- You can pass it key-value pairs for:
- All the custom function calls or objects that make up your model that you want to keep track of:
- Like a generator, discriminator, loss function, optimizer, etc.
ckpt = Checkpoint(discr_opt=discr_opt,
                  genrt_opt=genrt_opt,
                  wgan=wgan,
                  d_model=d_model,
                  g_model=g_model)
1.2) CheckpointManager
This literally manages the checkpoints you have defined to be stored at a location, and things like how many to keep. Definition from the docs: "Manages multiple checkpoints by keeping some and deleting unneeded ones"
How to initialise it:
- Initialise it with the Checkpoint object you created as the first argument.
- The directory where to save the checkpoint files.
- And you probably want to define how many to keep, since these can be large, complex models
manager = CheckpointManager(ckpt, "training_checkpoints_wgan", max_to_keep=3)
How to use it:
- We have set up the manager object with our specified checkpoints, so it's ready to use.
- Call this at the end of each training epoch
manager.save()
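To resume training later, rebuild the same trackable objects and restore from the newest checkpoint. A minimal self-contained sketch; the temporary directory and toy model here are stand-ins for the WGAN objects above:

```python
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()  # stand-in for "training_checkpoints_wgan"

# The same trackable objects must be rebuilt before restoring.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

model(tf.zeros((1, 3)))  # build the model so it has variables to save
manager.save()           # pretend an epoch just finished

# Later (e.g. in a new process): restore from the newest checkpoint.
# latest_checkpoint is None if nothing was saved yet, so this is safe
# to call unconditionally at startup.
ckpt.restore(manager.latest_checkpoint)
```

Calling restore before the first save is a no-op, so the same startup code works both for a fresh run and for resuming.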
2) ModelCheckpoint (callback)
You want to use this callback when you are not managing epoch iterations yourself. For example, when you have set up a relatively simple Sequential model and you call model.fit(), which manages the training process for you.
Definition from the docs: "Callback to save the Keras model or model weights at some frequency."
How to initialise it:
Pass in the path where to save the model.
The option save_weights_only is set to False by default:
- If you want to save only the weights, make sure to update this.
The option save_best_only is set to False by default:
- If you want to save only the best model instead of all of them, you can set this to True.
verbose is set to 0 (False), so you can set this to 1 to verify it is saving.
mc = ModelCheckpoint("training_checkpoints/cp.ckpt", save_best_only=True, save_weights_only=False)
How to use it:
- The model checkpoint callback is now ready for training.
- You pass the object into your callbacks list when you fit the model:
model.fit(X, y, epochs=100, callbacks=[mc])
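Putting the pieces together, here is a hedged end-to-end sketch; the toy model, data, and path are invented for illustration, and the `.keras` extension is the native full-model format in recent Keras versions. Note that save_best_only needs a metric to monitor; "loss" is used here since there is no validation data:

```python
import os
import tempfile

import tensorflow as tf

# Toy model and data, purely for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = tf.random.normal((16, 4))
y = tf.random.normal((16, 1))

# ".keras" saves the full model: architecture, weights, and optimizer state.
path = os.path.join(tempfile.mkdtemp(), "best.keras")
mc = tf.keras.callbacks.ModelCheckpoint(path, save_best_only=True,
                                        monitor="loss")

model.fit(x, y, epochs=2, verbose=0, callbacks=[mc])

# Reload the full model; no need to rebuild the architecture first.
restored = tf.keras.models.load_model(path)
```

Because the whole model is saved, `load_model` returns a ready-to-use (and already compiled) model, which is exactly the convenience the answers above attribute to the Keras side.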
