In main.py of the px2graph project, the training and validation part is shown below:

splits = [s for s in ['train', 'valid'] if opt.iters[s] > 0]
start_round = opt.last_round - opt.num_rounds

# Main training loop
for round_idx in range(start_round, opt.last_round):
    for split in splits:

        print("Round %d: %s" % (round_idx, split))
        loader.start_epoch(sess, split, train_flag, opt.iters[split] * opt.batchsize)

        flag_val = split == 'train'

        for step in tqdm(range(opt.iters[split]), ascii=True):
            global_step = step + round_idx * opt.iters[split]
            to_run = [sample_idx, summaries[split], loss, accuracy]
            if split == 'train': to_run += [optim]

            # Do image summaries at the end of each round
            do_image_summary = step == opt.iters[split] - 1
            if do_image_summary: to_run[1] = image_summaries[split]

            # Start with lower learning rate to prevent early divergence
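            # (the sigmoid ramp below moves t from ~0 to ~1 around step 5000
            # on a ~1000-step scale, so tmp_lr warms up from lr_start to lr_end)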
            t = 1/(1+np.exp(-(global_step-5000)/1000))
            lr_start = opt.learning_rate / 15
            lr_end = opt.learning_rate
            tmp_lr = (1-t) * lr_start + t * lr_end

            # Run computation graph
            result = sess.run(to_run, feed_dict={train_flag:flag_val, lr:tmp_lr})

            out_loss = result[2]
            out_accuracy = result[3]
            if sum(out_loss) > 1e5:
                print("Loss diverging...exiting before code freezes due to NaN values.")
                print("If this continues you may need to try a lower learning rate, a")
                print("different optimizer, or a larger batch size.")
                return

            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, global_step, out_loss, out_accuracy))

            # Log summaries: every validation step, every 20th training step, and the final image-summary step
            if split == 'valid' or (split == 'train' and step % 20 == 0) or do_image_summary:
                writer.add_summary(result[1], global_step)
                writer.flush()

    # Save training snapshot
    saver.save(sess, 'exp/' + opt.exp_id + '/snapshot')
    with open('exp/' + opt.exp_id + '/last_round', 'w') as f:
        f.write('%d\n' % round_idx)

It seems that the author only gets the result for each individual batch of the validation set. If I want to observe whether the model is improving or has reached its best performance, should I use the result over the whole validation set instead?

1 Answer

If the validation set is small enough, you can compute the loss and accuracy over the whole validation set during training to monitor performance. If the validation set is too large for that, it is better to compute batch-wise validation results over multiple steps and average them.
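
Here is a minimal sketch of the batch-averaging approach, reusing the names from the excerpt in the question (sess, loader, train_flag, loss, accuracy, opt); the run_validation helper itself is hypothetical, not part of px2graph:

import numpy as np

def run_validation(sess, loader, train_flag, loss, accuracy, opt):
    # Sketch only: assumes loader.start_epoch() queues opt.iters['valid']
    # batches, matching its use in the excerpt above
    loader.start_epoch(sess, 'valid', train_flag,
                       opt.iters['valid'] * opt.batchsize)

    losses, accs = [], []
    for _ in range(opt.iters['valid']):
        # No optimizer op in the fetch list, so no parameters are updated
        out_loss, out_acc = sess.run([loss, accuracy],
                                     feed_dict={train_flag: False})
        losses.append(np.sum(out_loss))
        accs.append(out_acc)

    # The mean over all validation batches approximates the full-set metrics
    # (it is exact when every batch has the same size)
    return np.mean(losses), np.mean(accs)

Calling this once per round and tracking the returned averages over time shows whether the model is still improving or has plateaued.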
