
I am using TensorFlow 2.3 and trying to save a model checkpoint every n epochs. n can be anything, but for now I am trying 10.

Per this thread, I tried save_freq = 'epoch' with period = 10, which works, but since the period parameter is deprecated, I wanted to try an alternative approach.

HEIGHT = 256
WIDTH = 256
CHANNELS = 3
EPOCHS = 100
BATCH_SIZE = 1
SAVE_PERIOD = 10

n_monet_samples = 21

checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_freq=SAVE_PERIOD * (n_monet_samples//BATCH_SIZE)
)

If I use save_freq=SAVE_PERIOD * (n_monet_samples//BATCH_SIZE) in the checkpoint callback definition, I get this error:

ValueError: Unrecognized save_freq: 210

I am not sure why, since per the Keras callback code, save_freq should be accepted as long as it is 'epoch' or an integer.

Please suggest.


1 Answer


I do not get any error when I run the same code on the same TensorFlow version (2.3):

import os
import tensorflow as tf

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
BATCH_SIZE = 1
SAVE_PERIOD = 10

n_monet_samples = 21

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1, save_freq=SAVE_PERIOD * (n_monet_samples//BATCH_SIZE))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=20,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])

Output:

Epoch 1/20
32/32 [==============================] - 0s 14ms/step - loss: 1.1152 - sparse_categorical_accuracy: 0.6890 - val_loss: 0.6934 - val_sparse_categorical_accuracy: 0.7940
Epoch 2/20
32/32 [==============================] - 0s 9ms/step - loss: 0.4154 - sparse_categorical_accuracy: 0.8840 - val_loss: 0.5317 - val_sparse_categorical_accuracy: 0.8330
Epoch 3/20
32/32 [==============================] - 0s 8ms/step - loss: 0.2787 - sparse_categorical_accuracy: 0.9270 - val_loss: 0.4854 - val_sparse_categorical_accuracy: 0.8400
Epoch 4/20
32/32 [==============================] - 0s 8ms/step - loss: 0.2230 - sparse_categorical_accuracy: 0.9420 - val_loss: 0.4525 - val_sparse_categorical_accuracy: 0.8590
Epoch 5/20
32/32 [==============================] - 0s 10ms/step - loss: 0.1549 - sparse_categorical_accuracy: 0.9620 - val_loss: 0.4275 - val_sparse_categorical_accuracy: 0.8650
Epoch 6/20
32/32 [==============================] - 0s 10ms/step - loss: 0.1110 - sparse_categorical_accuracy: 0.9770 - val_loss: 0.4251 - val_sparse_categorical_accuracy: 0.8630
Epoch 7/20
11/32 [=========>....................] - ETA: 0s - loss: 0.0936 - sparse_categorical_accuracy: 0.9886
Epoch 00007: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 14ms/step - loss: 0.0807 - sparse_categorical_accuracy: 0.9840 - val_loss: 0.4248 - val_sparse_categorical_accuracy: 0.8610
Epoch 8/20
32/32 [==============================] - 0s 10ms/step - loss: 0.0612 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.4058 - val_sparse_categorical_accuracy: 0.8650
Epoch 9/20
32/32 [==============================] - 0s 8ms/step - loss: 0.0489 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.4393 - val_sparse_categorical_accuracy: 0.8610
Epoch 10/20
32/32 [==============================] - 0s 6ms/step - loss: 0.0361 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4150 - val_sparse_categorical_accuracy: 0.8620
Epoch 11/20
32/32 [==============================] - 0s 10ms/step - loss: 0.0294 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4090 - val_sparse_categorical_accuracy: 0.8670
Epoch 12/20
32/32 [==============================] - 0s 7ms/step - loss: 0.0272 - sparse_categorical_accuracy: 0.9990 - val_loss: 0.4365 - val_sparse_categorical_accuracy: 0.8600
Epoch 13/20
32/32 [==============================] - 0s 8ms/step - loss: 0.0203 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4231 - val_sparse_categorical_accuracy: 0.8620
Epoch 14/20
 1/32 [..............................] - ETA: 0s - loss: 0.0115 - sparse_categorical_accuracy: 1.0000
Epoch 00014: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 9ms/step - loss: 0.0164 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4263 - val_sparse_categorical_accuracy: 0.8650
Epoch 15/20
32/32 [==============================] - 0s 7ms/step - loss: 0.0128 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4260 - val_sparse_categorical_accuracy: 0.8690
Epoch 16/20
32/32 [==============================] - 0s 7ms/step - loss: 0.0120 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4194 - val_sparse_categorical_accuracy: 0.8740
Epoch 17/20
32/32 [==============================] - 0s 9ms/step - loss: 0.0110 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4302 - val_sparse_categorical_accuracy: 0.8710
Epoch 18/20
32/32 [==============================] - 0s 6ms/step - loss: 0.0090 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4331 - val_sparse_categorical_accuracy: 0.8660
Epoch 19/20
32/32 [==============================] - 0s 7ms/step - loss: 0.0084 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4320 - val_sparse_categorical_accuracy: 0.8760
Epoch 20/20
16/32 [==============>...............] - ETA: 0s - loss: 0.0074 - sparse_categorical_accuracy: 1.0000
Epoch 00020: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 13ms/step - loss: 0.0072 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4280 - val_sparse_categorical_accuracy: 0.8750
<tensorflow.python.keras.callbacks.History at 0x7f90f0082cd0>

As you already know, save_freq is either 'epoch' or an integer.
With 'epoch', the callback saves the model after each epoch.
With an integer, the callback saves the model at the end of that many batches (training steps), not that many epochs.

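With the definition above, a checkpoint is saved every 210 batches. In the run shown here each epoch has 32 steps, so the saves land partway through epochs 7, 14 and 20. With 21 steps per epoch, as in your setup, 210 batches would correspond to every 10th epoch.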
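The batch counting described above can be sketched in plain Python. This is only an illustration of the arithmetic, assuming the callback keeps counting batches across epoch boundaries, which is what the output above suggests; the helper names save_freq_for_epochs and save_points are hypothetical and not part of Keras.

```python
def save_freq_for_epochs(n_samples, batch_size, save_period):
    """Number of batches between saves for one save every `save_period` epochs.

    Assumes steps per epoch is n_samples // batch_size, as in the question.
    """
    steps_per_epoch = n_samples // batch_size
    return save_period * steps_per_epoch


def save_points(save_freq, steps_per_epoch, epochs):
    """Return 1-based (epoch, batch) pairs at which a save would fall.

    Assumes batches are counted continuously across epochs.
    """
    total_batches = steps_per_epoch * epochs
    points = []
    for seen in range(save_freq, total_batches + 1, save_freq):
        epoch = (seen - 1) // steps_per_epoch + 1
        batch = (seen - 1) % steps_per_epoch + 1
        points.append((epoch, batch))
    return points


save_freq = save_freq_for_epochs(n_samples=21, batch_size=1, save_period=10)
print(save_freq)  # 210

# Run from this answer: 32 steps per epoch, 20 epochs.
print(save_points(save_freq, steps_per_epoch=32, epochs=20))
# [(7, 18), (14, 4), (20, 22)]

# Setup from the question: 21 steps per epoch, 100 epochs.
print(save_points(save_freq, steps_per_epoch=21, epochs=100)[:3])
# [(10, 21), (20, 21), (30, 21)]
```

The first result lines up with the epochs in which "saving model" appears in the output above, and the second shows that an integer save_freq only coincides with epoch boundaries when it is a multiple of the actual number of steps per epoch.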

Please check this for more details on the ModelCheckpoint arguments.