Tags: tensorflow, keras, deep-learning, pytorch, semantic-segmentation

Inconsistency in loss on SAME data for train and validation modes (TensorFlow)


I'm implementing a semantic segmentation model. As a good practice, I tested my training pipeline with just one image and tried to overfit that image. To my surprise, when training on exactly that image, the loss goes to near 0 as expected, but when evaluating on THE SAME IMAGE, the loss is much, much higher, and it keeps going up as training continues. So the segmentation output is garbage when run with training=False, but works perfectly when run with training=True.

So that anyone can reproduce this, I took the official segmentation tutorial and modified it slightly to train a convnet from scratch on just 1 image. The model is very simple: just a sequence of Conv2D layers with batch normalization and ReLU. The results are the following:

[Screenshot from 2020-11-03: loss and eval_loss curves]

As you can see, loss and eval_loss are very different, and running inference on the image gives a perfect result in training mode but garbage in eval mode.

I know batch normalization behaves differently at inference time, since it uses the moving statistics averaged during training. Nonetheless, since we are training on just 1 image and evaluating on that same image, this shouldn't happen, right? Moreover, I implemented the same architecture with the same optimizer in PyTorch, and this does not happen there: with PyTorch the model trains and eval_loss converges to the training loss.
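To see why the two modes can disagree even on identical input, here is a minimal pure-Python sketch of batch normalization for a single channel, assuming scalar gamma/beta and a plain list of activations. The function `batch_norm` and the sample values are hypothetical; epsilon=1e-3 and the initial moving statistics (mean 0, variance 1) mirror the Keras defaults. Training mode normalizes with the current batch's statistics, inference mode with the moving statistics, so the outputs agree only once the moving statistics have caught up with the batch statistics.

```python
import math

def batch_norm(x, moving_mean, moving_var, training, gamma=1.0, beta=0.0, eps=1e-3):
    """Batch-normalize one channel's activations `x` (a list of floats)."""
    if training:
        # Training mode: normalize with the current batch's mean and (biased) variance.
        mean = sum(x) / len(x)
        var = sum((v - mean) ** 2 for v in x) / len(x)
    else:
        # Inference mode: normalize with the accumulated moving statistics.
        mean, var = moving_mean, moving_var
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in x]

# Hypothetical activations of one channel for the single training image.
activations = [4.0, 6.0, 8.0]

train_out = batch_norm(activations, moving_mean=0.0, moving_var=1.0, training=True)
# Moving statistics still at their initial values (mean 0, variance 1):
stale_out = batch_norm(activations, moving_mean=0.0, moving_var=1.0, training=False)
# Moving statistics that have fully caught up with this image's statistics:
caught_up_out = batch_norm(activations, moving_mean=6.0, moving_var=8.0 / 3.0, training=False)

print(train_out)
print(stale_out)
print(caught_up_out)
```

Because the weights, and therefore the activations, keep changing during training, the moving statistics are always chasing a moving target; how quickly they catch up is governed by the layer's momentum parameter.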

The Colab notebook mentioned above is here: https://colab.research.google.com/drive/18LipgAmKVDA86n3ljFW8X0JThVEeFf0a#scrollTo=TWDATghoRczu (the PyTorch implementation is at the end).


Solution

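  • It had more to do with the default values that TensorFlow uses. BatchNormalization has a momentum parameter that controls how the batch statistics are averaged. The update rule is: moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum), and moving_variance is updated the same way with the batch variance. The Keras default is momentum=0.99, so each training step moves the moving statistics only 1% of the way toward the current batch statistics, and they can lag far behind the statistics used in training mode.

    If you set momentum=0.0 in the BatchNormalization layer, the moving statistics simply equal the statistics of the most recent batch (here, the same single image). If you do so, the validation loss almost immediately matches the training loss. momentum=0.9 also works and converges faster, as in PyTorch: PyTorch's BatchNorm defaults to momentum=0.1 but applies it to the new batch statistic rather than to the old running one, so it is equivalent to momentum=0.9 in Keras.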
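The effect of momentum can be illustrated with a small pure-Python simulation of the update rule above. It assumes the idealized case where the batch mean stays constant from step to step (as if the weights were frozen), which is a simplification; `moving_average`, `keras_momentum_from_torch`, and the value 6.0 are hypothetical. Starting from the Keras initial moving mean of 0, after k steps the remaining gap to the batch mean is scaled by momentum**k, so 100 steps at 0.99 still leave about 37% of the gap, 0.9 closes nearly all of it, and 0.0 closes it in a single step.

```python
def moving_average(batch_mean, momentum, steps, init=0.0):
    """Apply the Keras-style moving-mean update `steps` times.

    Assumes (hypothetically) that the batch mean is the same at every step.
    """
    moving_mean = init
    for _ in range(steps):
        moving_mean = moving_mean * momentum + batch_mean * (1.0 - momentum)
    return moving_mean

def keras_momentum_from_torch(torch_momentum):
    # PyTorch weights the *new* batch statistic by `momentum`,
    # while Keras weights the *old* moving statistic by it.
    return 1.0 - torch_momentum

batch_mean = 6.0  # hypothetical per-channel activation mean
for momentum in (0.99, keras_momentum_from_torch(0.1), 0.0):
    estimate = moving_average(batch_mean, momentum, steps=100)
    gap = abs(batch_mean - estimate)
    print(f"momentum={momentum:.2f}: moving_mean={estimate:.4f}, gap={gap:.4f}")
```

In Keras, the corresponding change is to pass momentum=0.9 (or momentum=0.0 for this single-image check) to tf.keras.layers.BatchNormalization.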