I have trained a TensorFlow seq2seq model to translate sentences from English to Spanish. I trained the model for 615,700 steps and saved the model checkpoints successfully. My training data contains 200,000 sentences each for English and Spanish. I want to continue training this model on 10K new sentences, starting from step 615,700. I am using the TensorFlow sequence-to-sequence model for this. How can I resume training from the last checkpoint? Here is the link that I am using for the translation.

I have 3 types of files in my train folder:

.index
.meta
.data
and a checkpoint file.

My new training data files are europarl_train.es-en.en and europarl_train.es-en.es, for English and Spanish sentences respectively.

I wrote the following code to load my model's .meta file and weights:

import data_utils
import seq2seq_model
import translate
import tensorflow as tf

with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('/home/i9/L-T_Model_Training/16_NOV_MODEL/train/translate.ckpt-615700.meta')
    saver.restore(sess,tf.train.latest_checkpoint('/home/i9/L-T_Model_Training/16_NOV_MODEL/train/.'))

How can I start retraining on this dataset?

James Z
Sandeep

1 Answer

Save

According to the TensorFlow 2 documentation, you can use the tf.train.Checkpoint and tf.train.CheckpointManager classes to save your model. Consider the following example:

import os
import tensorflow as tf

checkpoint_dir = './training_checkpoints'       # custom directory
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(model=model)   # model: your already-built model object
# max_to_keep is the number of most recent checkpoints to keep on disk
manager = tf.train.CheckpointManager(checkpoint=checkpoint, directory=checkpoint_dir, max_to_keep=3)

To save your model, call manager.save()
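As a minimal sketch of how max_to_keep behaves, the snippet below saves five times and then lists what the manager still tracks. The counter variable and the temporary directory are placeholders invented for this illustration; in practice you would pass your own model and checkpoint directory.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical stand-in for a model: a single tracked variable.
counter = tf.Variable(0)
checkpoint = tf.train.Checkpoint(counter=counter)
manager = tf.train.CheckpointManager(
    checkpoint=checkpoint,
    directory=tempfile.mkdtemp(),  # throwaway directory for this demo
    max_to_keep=3,
)

# Save five times; older checkpoints are deleted once more than three exist.
for _ in range(5):
    counter.assign_add(1)
    manager.save()

# manager.checkpoints lists the retained checkpoint prefixes, oldest first.
kept = [os.path.basename(path) for path in manager.checkpoints]
print(kept)
print(manager.latest_checkpoint)
```

Only the three most recent checkpoints should remain, and manager.latest_checkpoint points at the newest one.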

Load

Define the Checkpoint and CheckpointManager objects again, exactly as above, and run this code:

if manager.latest_checkpoint:
    # assert_consumed() fails loudly if the checkpoint and the objects do not fully match
    checkpoint.restore(manager.latest_checkpoint).assert_consumed()
    print("Restored from {}".format(manager.latest_checkpoint))

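If you get an error such as AssertionError: Unresolved object in checkpoint (root), replace assert_consumed with expect_partial. assert_consumed raises when a value stored in the checkpoint has no matching Python object, whereas expect_partial accepts a partial match and silences the related warnings (go here for the difference: link).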
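The difference can be reproduced with a small sketch. Here a checkpoint is written with two variables and then restored into an object that only knows about one of them. The variable names (weights, extra) and the temporary directory are made up for the illustration.

```python
import os
import tempfile

import tensorflow as tf

checkpoint_dir = tempfile.mkdtemp()  # throwaway directory for this demo

# Write a checkpoint holding two variables (hypothetical names).
weights = tf.Variable([1.0, 2.0])
extra = tf.Variable(7.0)
path = tf.train.Checkpoint(weights=weights, extra=extra).save(
    os.path.join(checkpoint_dir, "ckpt")
)

# Restore into an object that only declares one of the two variables.
new_weights = tf.Variable([0.0, 0.0])
partial = tf.train.Checkpoint(weights=new_weights)

try:
    partial.restore(path).assert_consumed()
    strict_ok = True
except AssertionError:
    # "extra" is stored in the checkpoint but has no matching Python object.
    strict_ok = False

# expect_partial() accepts the mismatch; the matched variable is still restored.
partial.restore(path).expect_partial()
print(strict_ok, new_weights.numpy())
```

If assert_consumed fails in your own code, it is worth checking first whether the mismatch is intended (for example, deliberately skipping optimizer state) before switching to expect_partial.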

The model is now loaded from the checkpoint. You can then load your new data, make sure its shapes match what the model expects, and continue training.
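The whole save, restore, and continue cycle might look like the sketch below. It is not the seq2seq model from the question: TinyModel, the step variable, the toy data, and the learning rate are all assumptions chosen to keep the example small. Tracking a step counter in the checkpoint is one way to let training resume from the step where it stopped.

```python
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical one-weight model standing in for the real network."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(0.0)


def build(checkpoint_dir):
    # Rebuild the same objects in every run so the checkpoint can be matched.
    model = TinyModel()
    step = tf.Variable(0, dtype=tf.int64)
    checkpoint = tf.train.Checkpoint(model=model, step=step)
    manager = tf.train.CheckpointManager(
        checkpoint=checkpoint, directory=checkpoint_dir, max_to_keep=3
    )
    return model, step, checkpoint, manager


def train(model, step, xs, ys, num_steps, lr=0.01):
    # Plain gradient descent on a squared error, counting steps as we go.
    for _ in range(num_steps):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model.w * xs - ys))
        grad = tape.gradient(loss, model.w)
        model.w.assign_sub(lr * grad)
        step.assign_add(1)


checkpoint_dir = tempfile.mkdtemp()  # throwaway directory for this demo

# First run: train on the original data and save.
model, step, checkpoint, manager = build(checkpoint_dir)
old_xs = tf.constant([1.0, 2.0, 3.0])
train(model, step, old_xs, 2.0 * old_xs, num_steps=5)
manager.save(checkpoint_number=int(step.numpy()))
saved_w = float(model.w.numpy())

# Second run: rebuild, restore the latest checkpoint, continue on new data.
model2, step2, checkpoint2, manager2 = build(checkpoint_dir)
if manager2.latest_checkpoint:
    checkpoint2.restore(manager2.latest_checkpoint).assert_consumed()
resumed_step = int(step2.numpy())
restored_w = float(model2.w.numpy())

new_xs = tf.constant([4.0, 5.0])
train(model2, step2, new_xs, 2.0 * new_xs, num_steps=3)
manager2.save(checkpoint_number=int(step2.numpy()))
print(resumed_step, int(step2.numpy()), manager2.latest_checkpoint)
```

With a real model you would normally also put the optimizer in tf.train.Checkpoint so that its state is restored along with the weights.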