
I'm implementing a semantic segmentation model on images. As a good practice, I tested my training pipeline with just one image and tried to overfit that image. To my surprise, when training with exactly the same image, the loss goes to near 0 as expected, but when evaluating THE SAME IMAGE, the loss is much, much higher, and it keeps going up as training continues. So the segmentation output is garbage when training=False, but when run with training=True it works perfectly.

So that anyone can reproduce this, I took the official segmentation tutorial and modified it slightly to train a convnet from scratch on just 1 image. The model is very simple: just a sequence of Conv2D with batch normalization and ReLU. The results are the following

[Screenshot from 2020-11-03 18-04-17: training log with loss vs. eval loss per epoch]

As you can see, the loss and eval loss are really different, and running inference on the image gives a perfect result in training mode but garbage in eval mode.

I know BatchNormalization behaves differently at inference time, since it uses the averaged statistics calculated during training. Nonetheless, since we are training with just 1 image and evaluating on that same image, this shouldn't happen, right? Moreover, I implemented the same architecture with the same optimizer in PyTorch, and this does not happen there: with PyTorch it trains and the eval loss converges to the train loss.

Here you can find the above-mentioned notebook: https://colab.research.google.com/drive/18LipgAmKVDA86n3ljFW8X0JThVEeFf0a#scrollTo=TWDATghoRczu (the PyTorch implementation is at the end).
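
For reference, the core of what the notebook does looks roughly like this. This is only a minimal sketch, not the exact tutorial code: the input size, the number of filters, the 3 segmentation classes and the random image/mask pair are placeholders standing in for the real data.

```python
import numpy as np
import tensorflow as tf

# One image and its mask (random placeholders standing in for the real pair).
image = np.random.rand(1, 128, 128, 3).astype("float32")
mask = np.random.randint(0, 3, size=(1, 128, 128)).astype("int32")  # 3 classes

def conv_bn_relu(x, filters):
    # The building block described above: Conv2D -> BatchNormalization -> ReLU.
    x = tf.keras.layers.Conv2D(filters, 3, padding="same")(x)
    x = tf.keras.layers.BatchNormalization()(x)
    return tf.keras.layers.ReLU()(x)

inputs = tf.keras.Input(shape=(128, 128, 3))
x = inputs
for filters in (32, 64, 64, 32):
    x = conv_bn_relu(x, filters)
outputs = tf.keras.layers.Conv2D(3, 1)(x)  # per-pixel class logits
model = tf.keras.Model(inputs, outputs)

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# Overfit the single image, then evaluate on the very same image.
model.fit(image, mask, epochs=200, verbose=0)
print("eval loss on the same image:", model.evaluate(image, mask, verbose=0))

# The discrepancy: forward passes with training=True vs. training=False.
logits_train = model(image, training=True)   # near-perfect segmentation
logits_eval = model(image, training=False)   # garbage output
```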

charlie

2 Answers


It has more to do with the default values that TensorFlow uses. BatchNormalization has a momentum parameter which controls the averaging of the batch statistics. The formula is: moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)

If you set momentum=0.0 in the BatchNorm layer, the averaged statistics match the statistics of the current batch (which is just 1 image) exactly. If you do so, you will see that the validation loss almost immediately matches the training loss. Also, if you try momentum=0.9 (the equivalent of PyTorch's default, since PyTorch defines momentum the other way around), it works and converges faster, as in PyTorch.
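
In tf.keras this just means passing momentum when you build the BatchNormalization layers. A minimal sketch of such a block (the layer arrangement and filter count are placeholders from the question, not the exact notebook code):

```python
import tensorflow as tf

# Keras update rule: moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
# - The Keras default is momentum=0.99, so the moving statistics adapt very slowly.
# - momentum=0.0 makes the moving statistics equal to the current batch statistics,
#   so the 1-image eval loss matches the training loss almost immediately.
# - momentum=0.9 matches PyTorch's default behaviour (PyTorch's momentum=0.1 is
#   defined the other way around).
def conv_bn_relu(x, filters, bn_momentum=0.9):
    x = tf.keras.layers.Conv2D(filters, 3, padding="same")(x)
    x = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(x)
    return tf.keras.layers.ReLU()(x)
```

Both values make the 1-image sanity check pass; 0.9 is the one suggested above for behaviour closer to the PyTorch implementation.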

charlie

It's old, but I had the same experience. I made a typo and got a model that could overfit a dataset that couldn't possibly fit in its memory: consistently near-100% predictions in training mode, gibberish in eval on the same data. I dug into it, just out of curiosity, and found that the model learned to "intentionally" shift its outputs ever so slightly, sacrificing a small percentage of loss, but by doing so producing a gradient that was used as extra memory. It was so hell-bent on overfitting, apparently, that it went outside itself for that. I don't know what to do with it yet; it looks like it is more of a feature than a bug.

Bo Ba