I would like to stop training a network as soon as the error on the validation set starts to increase. I'm using a BasicNetwork with RPROP as the training algorithm, and I have the following training loop:
```java
void trainCrossValidation(BasicNetwork network, MLDataSet training, MLDataSet validation) {
    FoldedDataSet folded = new FoldedDataSet(training);
    Train train = new ResilientPropagation(network, folded);
    CrossValidationKFold trainFolded = new CrossValidationKFold(train, KFOLDS);
    trainFolded.addStrategy(new SimpleEarlyStoppingStrategy(validation));

    int epoch = 1;
    do {
        trainFolded.iteration();
        logger.debug("Iter. " + epoch + ": Error = " + trainFolded.getError());
        epoch++;
    } while (!trainFolded.isTrainingDone() && epoch < MAX_ITERATIONS);
    trainFolded.finishTraining();
}
```
Unfortunately, it does not work as expected: the method takes a very long time to run and does not seem to stop at the right moment. I want training to stop at exactly the epoch where the validation error begins to grow, that is, after the ideal number of training iterations.
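A note on "exactly the instant": an increase can only be detected one epoch after the minimum, so early stopping usually remembers the best epoch (and, in a real run, a copy of the weights from that epoch) rather than the last one. The sketch below is a hypothetical helper, not part of Encog; it assumes you compute the validation error yourself after each `iteration()`, for example with `network.calculateError(validation)` if your Encog version's `BasicNetwork` exposes that method. `patience = 1` stops at the first epoch that fails to improve; a larger value tolerates noise in the validation curve.

```java
/** Hypothetical helper (not Encog API): tracks validation error per epoch. */
final class EarlyStopper {
    private final int patience;  // epochs tolerated without improvement
    private double bestError = Double.POSITIVE_INFINITY;
    private int bestEpoch = -1;
    private int epochsWithoutImprovement = 0;

    EarlyStopper(int patience) {
        if (patience < 1) throw new IllegalArgumentException("patience must be >= 1");
        this.patience = patience;
    }

    /** Records one epoch's validation error; returns true when training should stop. */
    boolean update(int epoch, double validationError) {
        if (validationError < bestError) {
            bestError = validationError;
            bestEpoch = epoch;
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;
        }
        return epochsWithoutImprovement >= patience;
    }

    int bestEpoch() { return bestEpoch; }
    double bestError() { return bestError; }
}

final class EarlyStoppingDemo {
    /** Replays per-epoch validation errors; returns {stopEpoch, bestEpoch}. */
    static int[] replay(double[] validationErrors, int patience) {
        EarlyStopper stopper = new EarlyStopper(patience);
        for (int i = 0; i < validationErrors.length; i++) {
            int epoch = i + 1;
            if (stopper.update(epoch, validationErrors[i])) {
                return new int[] {epoch, stopper.bestEpoch()};
            }
        }
        return new int[] {validationErrors.length, stopper.bestEpoch()};
    }

    public static void main(String[] args) {
        // Simulated validation curve: minimum at epoch 4, rising afterwards.
        double[] errors = {0.50, 0.40, 0.35, 0.33, 0.34, 0.36, 0.37};
        int[] r = replay(errors, 1);
        System.out.println("stopped at epoch " + r[0] + ", best epoch " + r[1]);
    }
}
```

With `patience = 1` this prints `stopped at epoch 5, best epoch 4`: the rise is noticed at epoch 5, and epoch 4 is the one whose weights you would keep.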
Is there a way to extract the validation data directly from the cross-validation folds instead of creating another MLDataSet exclusively for validation? If so, how?
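Independently of how Encog's `FoldedDataSet` exposes its folds (not shown here), the bookkeeping behind "use one fold as validation data" is just an index split. The sketch below is a hypothetical helper, not Encog API: it divides `n` rows into `k` contiguous, near-equal folds and returns the held-out indices for fold `f` and the training indices for the rest. It assumes the rows are already shuffled; you would then copy those rows into two separate data sets.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper (not Encog API): contiguous k-fold split of row indices. */
final class KFoldSplit {
    /** Returns {start, end} (end exclusive) of fold f when n rows are split into k folds. */
    static int[] foldBounds(int n, int k, int f) {
        if (k < 2 || k > n) throw new IllegalArgumentException("need 2 <= k <= n");
        if (f < 0 || f >= k) throw new IllegalArgumentException("fold out of range");
        int base = n / k;
        int extra = n % k;  // the first 'extra' folds get one additional row
        int start = f * base + Math.min(f, extra);
        int size = base + (f < extra ? 1 : 0);
        return new int[] {start, start + size};
    }

    /** Rows held out for validation in fold f. */
    static List<Integer> validationIndices(int n, int k, int f) {
        int[] b = foldBounds(n, k, f);
        List<Integer> out = new ArrayList<>();
        for (int i = b[0]; i < b[1]; i++) out.add(i);
        return out;
    }

    /** Rows used for training in fold f (everything not held out). */
    static List<Integer> trainingIndices(int n, int k, int f) {
        int[] b = foldBounds(n, k, f);
        List<Integer> out = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (i < b[0] || i >= b[1]) out.add(i);
        }
        return out;
    }
}
```

For `n = 10, k = 3` the folds are rows 0-3, 4-6 and 7-9, so every row is used for validation exactly once.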
Which parameter should I use to stop training? Could you show me the modifications needed to get the expected behavior? How can I use cross-validation and SimpleEarlyStoppingStrategy together? I'm quite confused.
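One common way to combine the two ideas (a general technique, not a description of what Encog's `CrossValidationKFold` does internally) is to early-stop separately on each fold's held-out part, average the best epoch counts, and then retrain on all the data for that many epochs. In the sketch below, `FoldRun` is a hypothetical abstraction: in a real Encog setup each instance would wrap a freshly reset network and its own `ResilientPropagation` on that fold's training rows, and that wiring is left as an assumption. `replaying` is only a test double that replays a fixed error curve.

```java
import java.util.List;

/** Hypothetical abstraction: one training run on one fold. */
interface FoldRun {
    /** Runs one training epoch and returns the error on this fold's held-out rows. */
    double epochThenValidate();
}

final class CrossValidatedEarlyStopping {
    /**
     * For each fold, trains until the held-out error has not improved for
     * `patience` epochs (or maxEpochs is reached), records the epoch with the
     * lowest held-out error, and returns the rounded average of those epochs.
     */
    static int chooseEpochs(List<FoldRun> folds, int patience, int maxEpochs) {
        if (folds.isEmpty()) throw new IllegalArgumentException("no folds");
        if (patience < 1 || maxEpochs < 1) throw new IllegalArgumentException("bad limits");
        long sum = 0;
        for (FoldRun run : folds) {
            double best = Double.POSITIVE_INFINITY;
            int bestEpoch = 0;
            int sinceBest = 0;
            for (int epoch = 1; epoch <= maxEpochs; epoch++) {
                double err = run.epochThenValidate();
                if (err < best) {
                    best = err;
                    bestEpoch = epoch;
                    sinceBest = 0;
                } else if (++sinceBest >= patience) {
                    break;  // held-out error stopped improving on this fold
                }
            }
            sum += bestEpoch;
        }
        return (int) Math.round((double) sum / folds.size());
    }

    /** Test double: a FoldRun that replays a fixed error curve (last value repeats). */
    static FoldRun replaying(double... errors) {
        int[] next = {0};
        return () -> errors[Math.min(next[0]++, errors.length - 1)];
    }
}
```

The returned number is the epoch budget for the final training run on the full data set, which then needs no separate validation set of its own.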
Thank you so much for any assistance.