
To avoid overfitting, it's necessary to pause training after every X steps and validate the result. If the validation curve (iterations vs. loss) crosses the training curve (iterations vs. loss), I need to stop training.

How can I validate the training result to avoid overfitting?

```python
import time

import tensorflow as tf


def train(self, dataset):
    num_samples = len(dataset)
    ram_train = []  # CPU-usage samples, recorded every 10 epochs
    print('Training...')
    tic = time.time()
    with tf.compat.v1.Session() as sess:
        # start a TensorFlow session and initialize all variables
        sess.run(tf.compat.v1.global_variables_initializer())
        for i in range(self.epoch):  # iterate through the number of epochs
            for j in range(num_samples):  # train the network one data item at a time
                loss, _ = sess.run([self.loss, self.train_op], feed_dict={self.x: [dataset[j]]})

            if i % 10 == 0:
                ram_train.append(cpu_usage(1))  # cpu_usage is the asker's own helper
                print(f'epoch {i}: loss = {loss}')
                self.saver.save(sess, f'./model_hidden{self.hidden}_wdw{self.window}.ckpt')
        # save the final weights once, after the last epoch
        self.saver.save(sess, f'./model_hidden{self.hidden}_wdw{self.window}.ckpt')
    tac = time.time()
    print('Done.')
    return loss, ram_train, (tac - tic)
```

I created a class named Autoencoder, and one of its methods trains the ANN. The code runs, but the model overfits. I searched online and checked the TensorFlow Session documentation for a parameter I could add to my code, without success.
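Because this code drives training through `tf.compat.v1.Session` rather than `model.fit`, Keras callbacks do not plug into it directly. Here is a minimal sketch of adding validation to a session loop of this kind. It is self-contained and assumes a hypothetical toy autoencoder built from raw ops on random data (`train_data`, `val_data`, the 8-3-8 layer sizes and `patience` are illustrative, not from the question). The key points: validation loss is computed by running only the loss tensor on held-out data, without the training op, so no weights change; and a checkpoint is saved only when validation loss improves, then restored at the end.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Graph mode, matching the question's tf.compat.v1.Session workflow.
tf.compat.v1.disable_eager_execution()

# Hypothetical toy data: 200 training and 50 held-out validation samples.
rng = np.random.default_rng(0)
train_data = rng.normal(size=(200, 8)).astype("float32")
val_data = rng.normal(size=(50, 8)).astype("float32")

# Tiny autoencoder built from raw ops (8 -> 3 -> 8).
x = tf.compat.v1.placeholder(tf.float32, shape=[None, 8])
w_enc = tf.Variable(tf.random.normal([8, 3], stddev=0.1))
b_enc = tf.Variable(tf.zeros([3]))
w_dec = tf.Variable(tf.random.normal([3, 8], stddev=0.1))
b_dec = tf.Variable(tf.zeros([8]))
recon = tf.matmul(tf.nn.relu(tf.matmul(x, w_enc) + b_enc), w_dec) + b_dec
loss = tf.reduce_mean(tf.square(recon - x))
train_op = tf.compat.v1.train.AdamOptimizer(1e-3).minimize(loss)
saver = tf.compat.v1.train.Saver()
ckpt_path = os.path.join(tempfile.mkdtemp(), "best.ckpt")

max_epochs, patience = 30, 3
best_val, wait, history = float("inf"), 0, []
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for epoch in range(max_epochs):
        for sample in train_data:  # one-by-one updates, as in the question
            sess.run(train_op, feed_dict={x: [sample]})
        # Evaluate both sets WITHOUT train_op, so no weights change here.
        train_loss = sess.run(loss, feed_dict={x: train_data})
        val_loss = sess.run(loss, feed_dict={x: val_data})
        history.append((train_loss, val_loss))  # both curves, e.g. for plotting
        if val_loss < best_val:  # validation improved: checkpoint this epoch
            best_val, wait = val_loss, 0
            saver.save(sess, ckpt_path)
        else:  # no improvement: stop after `patience` such epochs in a row
            wait += 1
            if wait >= patience:
                print(f"stopping at epoch {epoch}: val_loss stopped improving")
                break
    saver.restore(sess, ckpt_path)  # roll back to the best epoch's weights
```

In the `Autoencoder` class, the equivalent of the validation step would be `sess.run(self.loss, feed_dict={self.x: ...})` on a held-out split.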

  • Hi, I'm not sure that the rule of thumb "If the curve (iterations x loss) for validation crosses the curve (iterations x loss) for training I need to stop the train" is robust enough to use. Anyway, some explanation of how it is implemented would make your question clearer. It is not clear how the class named Autoencoder is related to your question. I would also recommend using a more modern version of TensorFlow, especially for the purpose of learning. – Nir Nov 03 '22 at 13:46
  • @Nir, sorry. I meant that at each epoch the model computes the loss on both the training and validation sets; if the validation loss begins to increase, training should stop. I'm using the latest version of TensorFlow. I just need to train an ANN model and, during training, check the validation loss so I can stop before the model overfits. – Mariana Flávio Nov 03 '22 at 15:07
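The rule described in this comment (compute the validation loss every epoch, stop once it starts rising) can be written independently of any framework, so it also works inside a hand-written `tf.compat.v1.Session` loop. Below is a minimal sketch; the class name `ValidationStopper` and its `patience` and `min_delta` parameters are hypothetical, modeled on the Keras `EarlyStopping` arguments. Waiting `patience` epochs instead of stopping at the first increase makes the rule less sensitive to noise in the validation curve.

```python
class ValidationStopper:
    """Hypothetical helper: signals a stop once validation loss has not
    improved by more than `min_delta` for `patience` consecutive epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = None
        self.wait = 0

    def update(self, epoch, val_loss):
        """Record this epoch's validation loss; return True to stop training."""
        if val_loss < self.best - self.min_delta:
            self.best, self.best_epoch, self.wait = val_loss, epoch, 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Usage with a made-up validation curve whose minimum is at epoch 3.
val_curve = [0.90, 0.70, 0.55, 0.50, 0.52, 0.53, 0.56, 0.60]
stopper = ValidationStopper(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_curve):
    if stopper.update(epoch, val_loss):
        stopped_at = epoch
        print(f"stop at epoch {epoch}; best epoch was {stopper.best_epoch}")
        # prints "stop at epoch 6; best epoch was 3"
        break
```

Saving a checkpoint whenever `update` records a new best gives the same effect as `restore_best_weights=True` in Keras.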



You can use the Keras `EarlyStopping` callback; the documentation is [here][1]. Set up the callback as shown below:

```python
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=3,
                                      verbose=1, mode='auto', baseline=None, restore_best_weights=True)
```

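Then pass `callbacks=[es]` to `model.fit`. Note that `val_loss` only exists if you also give `model.fit` validation data, via `validation_data=` or `validation_split=`; otherwise the callback has nothing to monitor.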
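A minimal end-to-end sketch of this approach for a Keras autoencoder follows. The random data, the 16-8-16 layer sizes and `validation_split=0.2` are illustrative assumptions; for an autoencoder the input is also the target, so `data` is passed twice. It requires the model to be a Keras model: the session-based loop from the question would have to be ported to Keras to use callbacks this way.

```python
import numpy as np
import tensorflow as tf

# Illustrative random data; an autoencoder reconstructs its input, so y = x.
rng = np.random.default_rng(0)
data = rng.normal(size=(500, 16)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.Dense(8, activation="relu"),  # encoder
    tf.keras.layers.Dense(16),                    # decoder
])
model.compile(optimizer="adam", loss="mse")

es = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", min_delta=0, patience=3,
    verbose=1, mode="auto", baseline=None, restore_best_weights=True)

history = model.fit(
    data, data,
    epochs=100,
    batch_size=32,
    validation_split=0.2,  # holds out the last 20% of samples as validation data
    callbacks=[es],
    verbose=0,
)
print(f"trained for {len(history.history['loss'])} epochs")
```

`history.history['loss']` and `history.history['val_loss']` hold the two curves from the question, one value per epoch, if you want to plot them.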

Early stopping with `restore_best_weights=True` returns your model with the weights from the epoch that had the best value of the monitored metric (here, the lowest `val_loss`). This is one way to deal with overfitting, but it is NOT the best way to get your model to the lowest validation loss. Early stopping does not "prevent" overfitting; it "detects" it and returns your model weights from the epoch before overfitting set in. It is better to prevent overfitting in the first place, and there are several ways to do that. One is to add dropout layers to your model. Another is to add regularizers. For example, if your model has dense layers, you can add regularizers to them as shown below:

```python
x = Dense(256, kernel_regularizer=regularizers.l2(0.016), activity_regularizer=regularizers.l1(0.006),
          bias_regularizer=regularizers.l1(0.006), activation='relu')(x)
```
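For context, here is one way a regularized `Dense` layer like the one above and a dropout layer might fit into a small functional-API autoencoder. The input width of 64, the 32-unit bottleneck and the 0.2 dropout rate are illustrative assumptions, not recommendations; the regularization factors are the answer's. Keras adds the regularization penalties to the training loss automatically (they are listed in `model.losses`), and `Dropout` only drops units when the model is called with `training=True`, as happens inside `fit`.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, regularizers

inputs = tf.keras.Input(shape=(64,))
x = layers.Dense(
    256,
    kernel_regularizer=regularizers.l2(0.016),    # penalizes large weights
    activity_regularizer=regularizers.l1(0.006),  # penalizes large activations
    bias_regularizer=regularizers.l1(0.006),
    activation="relu",
)(inputs)
x = layers.Dropout(0.2)(x)  # zeroes ~20% of units at random, during training only
encoded = layers.Dense(32, activation="relu")(x)  # bottleneck
outputs = layers.Dense(64)(encoded)               # reconstruction

model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

# Inference-mode forward pass: dropout is inactive here.
out = model(np.zeros((2, 64), dtype="float32"), training=False)
print(out.shape, len(model.losses))
```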

One of the best ways to reach a lower validation loss is to use an adjustable learning rate. This is easy to do with the Keras callback `ReduceLROnPlateau` (documentation [here][2]), which multiplies the learning rate by `factor` whenever the monitored metric has not improved for `patience` epochs. My recommended settings for this callback are shown below:

```python
rlronp = tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.4, patience=2, verbose=1,
                                              mode="auto", min_delta=0.0001, cooldown=0, min_lr=0)
```

Then pass `callbacks=[es, rlronp]` to `model.fit`.
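The sketch below runs both callbacks together. To make the learning-rate reduction easy to observe, it deliberately uses an all-zero dataset (a contrived assumption, not a realistic setup): with zero inputs and zero-initialized biases, the model's output and the loss are exactly 0 from the first epoch and the gradients are zero, so `val_loss` never improves. With `patience=2`, `ReduceLROnPlateau` should lower the learning rate from 0.001 after the third epoch, and `EarlyStopping` (patience 3) should then end training a few epochs in.

```python
import numpy as np
import tensorflow as tf

# Contrived data: all zeros, so the loss is exactly 0 from the start and
# val_loss can never improve (a plateau by construction).
data = np.zeros((100, 4), dtype="float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(4),
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss="mse")

es = tf.keras.callbacks.EarlyStopping(monitor="val_loss", min_delta=0, patience=3,
                                      verbose=1, mode="auto", restore_best_weights=True)
rlronp = tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.4, patience=2,
                                              verbose=1, mode="auto", min_delta=0.0001,
                                              cooldown=0, min_lr=0)

history = model.fit(data, data, epochs=50, validation_split=0.2,
                    callbacks=[es, rlronp], verbose=0)

# Read the optimizer's learning rate after training has stopped.
final_lr = float(np.array(model.optimizer.learning_rate))
print(f"epochs run: {len(history.history['loss'])}, final learning rate: {final_lr:.6f}")
```

On real data the same `callbacks=[es, rlronp]` list is passed unchanged; only the model and data differ.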



  [1]: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping
  [2]: https://keras.io/api/callbacks/reduce_lr_on_plateau/
Gerry P