Tags: python, tensorflow, image-processing, pytorch, autoencoder

Why is Normalization causing my network to have exploding gradients in training?


I've built a network (in PyTorch) that performs well for image restoration. It is an autoencoder with a ResNet-50 encoder backbone; however, I am only using a batch size of 1, because I'm experimenting with some frequency-domain processing that only allows me to handle one image at a time.

I have found that my network performs reasonably well, but only if I remove all batch normalization from it. Of course, batch norm is poorly suited to a batch size of 1, because its statistics are then computed from a single image, so I switched to group norm, which is designed for exactly this case. However, even with group norm my gradients explode. Training can go very well for 20 to 100 epochs and then it's game over. Sometimes it recovers and then explodes again.
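
The difference between the two layers can be checked directly by normalizing the same image alone and as part of a larger batch. This is a minimal sketch, assuming PyTorch's `nn.GroupNorm` and `nn.BatchNorm2d` in their default training mode; the batch of 4, the 64 channels, the 32 groups and the 8x8 spatial size are arbitrary illustrative values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Four random feature maps: (batch, channels, height, width). Sizes are arbitrary.
x = torch.randn(4, 64, 8, 8)

gn = nn.GroupNorm(num_groups=32, num_channels=64)
bn = nn.BatchNorm2d(64)  # modules start in training mode

# GroupNorm computes statistics per sample (over each group of channels),
# so the first image is normalized the same way alone or inside a batch.
gn_alone = gn(x[:1])
gn_batched = gn(x)[:1]

# BatchNorm in training mode pools statistics over the whole batch, so the
# first image's output depends on which other images it is batched with.
bn_alone = bn(x[:1])
bn_batched = bn(x)[:1]

print("GroupNorm batch-independent:", torch.allclose(gn_alone, gn_batched, atol=1e-5))
print("BatchNorm batch-independent:", torch.allclose(bn_alone, bn_batched, atol=1e-5))
```

GroupNorm should print `True` and BatchNorm `False`. With a batch size of 1, the BatchNorm statistics come from one image during training, while its running estimates, which are used in `eval()` mode, average over many very different images.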

I should also say that during training, every image is given a randomly chosen, and often wildly different, amount of noise, so that the network learns to handle a range of noise levels. This has been done before, but perhaps combined with a batch size of 1 it is problematic.
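
The per-image augmentation described in the post (random rotations and flips, plus a fresh noise level between 0 and 30 for each image) can be sketched as below. This is a hypothetical helper, `make_training_pair`, not the author's code; it assumes pixel values on a 0 to 255 scale, that the noise level is the standard deviation of additive Gaussian noise, and that rotations are multiples of 90 degrees.

```python
import torch


def make_training_pair(clean, max_sigma=30.0, generator=None):
    """Build one (noisy, clean) pair from an image tensor of shape (C, H, W).

    Assumptions: pixel values are on a 0-255 scale, the noise level is the
    standard deviation of additive Gaussian noise, and rotations are
    multiples of 90 degrees.
    """
    # Random rotation by 0, 90, 180 or 270 degrees in the (H, W) plane.
    k = int(torch.randint(0, 4, (1,), generator=generator))
    clean = torch.rot90(clean, k, dims=(1, 2))

    # Random horizontal and vertical flips, each with probability 0.5.
    if torch.rand(1, generator=generator).item() < 0.5:
        clean = torch.flip(clean, dims=(2,))
    if torch.rand(1, generator=generator).item() < 0.5:
        clean = torch.flip(clean, dims=(1,))

    # A fresh noise level per image: consecutive samples can range from
    # almost clean (sigma near 0) to heavily corrupted (sigma near max_sigma).
    sigma = torch.rand(1, generator=generator).item() * max_sigma
    noisy = clean + sigma * torch.randn(clean.shape, generator=generator)
    return noisy, clean, sigma


g = torch.Generator().manual_seed(0)
image = torch.rand(3, 64, 64) * 255.0
for _ in range(3):
    noisy, target, sigma = make_training_pair(image, generator=g)
    print(f"sigma={sigma:5.2f}  residual std={(noisy - target).std().item():5.2f}")
```

Printing the residual standard deviation for a few draws shows how much the corruption varies from one sample to the next, which with a batch size of 1 also means from one optimizer step to the next.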

I'm scratching my head over this one and wondering if anyone has suggestions. I've tuned my learning rate and clipped the gradients, but this only treats the symptom rather than the actual issue. I can post some code, but I'm not sure where to start, and I'm hoping someone can offer a theory. Any ideas? Thanks!
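
Clipping alone hides the symptom, but logging the gradient norm before it is clipped shows exactly when the explosion starts, which can then be matched against, for example, the noise level of the image used in that step. A minimal sketch, assuming a standard PyTorch training loop; `train_step`, the toy model, `max_norm=1.0`, the Adam learning rate and the choice to skip updates whose gradient norm is not finite are illustrative assumptions, not details from the post.

```python
import torch
import torch.nn as nn


def train_step(model, optimizer, loss_fn, noisy, clean, max_norm=1.0):
    """Run one optimization step with global gradient-norm clipping.

    Returns the loss and the gradient norm measured *before* clipping, so
    that spikes can be logged and matched to the input that caused them.
    """
    optimizer.zero_grad()
    loss = loss_fn(model(noisy), clean)
    loss.backward()
    # clip_grad_norm_ rescales the gradients in place and returns the total
    # norm computed before rescaling.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # If the norm is inf/NaN the gradients are unusable, so skip the update.
    if torch.isfinite(total_norm):
        optimizer.step()
    return loss.item(), total_norm.item()


torch.manual_seed(0)
# A toy stand-in for the real autoencoder (hypothetical, for illustration only).
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.GroupNorm(8, 16),
    nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
clean = torch.rand(1, 3, 32, 32)
noisy = clean + 0.1 * torch.randn_like(clean)

loss, grad_norm = train_step(model, optimizer, nn.MSELoss(), noisy, clean)
print(f"loss={loss:.4f}  grad norm before clipping={grad_norm:.4f}")
```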


Solution

  • To answer my own question: my network was unstable in training because, with a batch size of 1, the data differs too much from one batch to the next. Or, as the papers like to put it, there is too high an internal covariate shift.

    Not only were my images drawn from a very large, varied dataset, but they were also randomly rotated and flipped. On top of this, a random Gaussian noise level between 0 and 30 was chosen for each image, so one image might have little to no noise while the next might be barely recognizable.

    In the question above I mentioned group norm. However, my network is complex and some of the code is adapted from other work, and there were still batch norm layers hidden in it that I had missed. I removed them. I'm still not sure why BN made things worse.

    Following this, I reimplemented group norm with a group size of 32, and training is much more stable now.

    In short, removing the leftover BN layers and adding group norm helped.
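
Leftover batch-norm layers in adapted code are easy to miss, so it can help to search the model for them programmatically and swap them for group norm. This is a hedged sketch: `find_batchnorm` and `replace_bn_with_gn` are hypothetical helpers, and `num_groups=32` follows the default used in the Group Normalization paper. The answer's group setting of 32 could also mean 32 channels per group, in which case each layer would need `num_groups = channels // 32` instead. If the ResNet-50 backbone comes from torchvision (the post does not say), note that torchvision's ResNet constructors use `nn.BatchNorm2d` unless a different `norm_layer` is passed.

```python
import torch.nn as nn

# Public batch-norm classes; extend this tuple if the code defines its own BN variants.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def find_batchnorm(model):
    """Return the qualified names of every batch-norm module in `model`."""
    return [name for name, m in model.named_modules() if isinstance(m, BN_TYPES)]


def replace_bn_with_gn(module, num_groups=32):
    """Recursively replace batch-norm layers with GroupNorm, in place.

    nn.GroupNorm's first argument is the number of groups, and the channel
    count must be divisible by it (otherwise GroupNorm raises an error).
    """
    for name, child in list(module.named_children()):
        if isinstance(child, BN_TYPES):
            # New GroupNorm layers start from weight=1, bias=0; any learned
            # (or pretrained) BN affine parameters are discarded.
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            replace_bn_with_gn(child, num_groups)
    return module


# A toy model with a batch-norm layer "hidden" inside a nested block (hypothetical).
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.GroupNorm(32, 64),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64)),
)
print(find_batchnorm(model))  # ['3.1']
replace_bn_with_gn(model)
print(find_batchnorm(model))  # []
```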