Tags: deep-learning, chainer

Increased Global Batch Size in Data Parallelism Causes OOM Error


I am increasing the batch size as I increase the number of GPUs while training AlexNet on the ImageNet dataset. I start with a batch size of 1024 on 4 GPUs, then 2048 on 8 GPUs, and training works fine. However, when I attempt 4096 on 16 GPUs, I get OOM errors. Ideally this shouldn't happen because, in data parallelism, the number of samples per GPU stays the same. I am using ChainerMN for the training.


Solution

  • Figured this out finally. Don't increase the batch size as you increase the number of GPUs. In ChainerMN data parallelism, the batch size you pass to the iterator is the per-GPU (per-process) batch size, not the global one. If you set it to, say, 32, each GPU processes 32 samples per iteration, so the effective global batch size is already 32 × number of GPUs. Scaling the iterator batch size yourself multiplies the per-GPU memory footprint, which is why 4096 on 16 GPUs runs out of memory. See the sketch below.
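
    A minimal, self-contained sketch of this pattern (not the asker's actual script): each MPI process owns one GPU, the iterator gets the fixed per-GPU batch size, and ChainerMN's multi-node optimizer averages gradients across processes. The compact AlexNet definition follows the style of the Chainer ImageNet example, the random `TupleDataset` stands in for the real preprocessed ImageNet data, and hyperparameters such as `lr=0.01` and `batchsize=32` are illustrative assumptions.

        import numpy as np
        import chainer
        import chainer.functions as F
        import chainer.links as L
        import chainermn
        from chainer import training


        class AlexNet(chainer.Chain):
            """Compact AlexNet-style network (227x227 inputs), after the Chainer ImageNet example."""

            def __init__(self, n_classes=1000):
                super(AlexNet, self).__init__()
                with self.init_scope():
                    self.conv1 = L.Convolution2D(None, 96, 11, stride=4)
                    self.conv2 = L.Convolution2D(None, 256, 5, pad=2)
                    self.conv3 = L.Convolution2D(None, 384, 3, pad=1)
                    self.conv4 = L.Convolution2D(None, 384, 3, pad=1)
                    self.conv5 = L.Convolution2D(None, 256, 3, pad=1)
                    self.fc6 = L.Linear(None, 4096)
                    self.fc7 = L.Linear(None, 4096)
                    self.fc8 = L.Linear(None, n_classes)

            def __call__(self, x):
                h = F.max_pooling_2d(F.relu(self.conv1(x)), 3, stride=2)
                h = F.max_pooling_2d(F.relu(self.conv2(h)), 3, stride=2)
                h = F.relu(self.conv3(h))
                h = F.relu(self.conv4(h))
                h = F.max_pooling_2d(F.relu(self.conv5(h)), 3, stride=2)
                h = F.dropout(F.relu(self.fc6(h)))
                h = F.dropout(F.relu(self.fc7(h)))
                return self.fc8(h)


        def main():
            # Per-GPU (per-process) batch size. Keep this fixed when adding GPUs;
            # the effective global batch size is batchsize * comm.size.
            batchsize = 32

            comm = chainermn.create_communicator('pure_nccl')
            device = comm.intra_rank  # one MPI process per GPU

            model = L.Classifier(AlexNet())
            chainer.cuda.get_device_from_id(device).use()
            model.to_gpu()

            if comm.rank == 0:
                # Placeholder for the preprocessed ImageNet training set
                # (random 227x227 crops with random labels, for illustration only).
                images = np.random.rand(1024, 3, 227, 227).astype(np.float32)
                labels = np.random.randint(1000, size=1024).astype(np.int32)
                train = chainer.datasets.TupleDataset(images, labels)
            else:
                train = None
            # Rank 0 loads the data; every process receives an equal shard.
            train = chainermn.scatter_dataset(train, comm, shuffle=True)

            # Each process iterates over its own shard with the per-GPU batch size.
            train_iter = chainer.iterators.SerialIterator(train, batchsize)

            # Wrap the optimizer so gradients are all-reduced across processes.
            optimizer = chainermn.create_multi_node_optimizer(
                chainer.optimizers.MomentumSGD(lr=0.01), comm)
            optimizer.setup(model)

            updater = training.StandardUpdater(train_iter, optimizer, device=device)
            trainer = training.Trainer(updater, (10, 'epoch'), out='result')
            trainer.run()


        if __name__ == '__main__':
            main()

    Launched with, for example, `mpiexec -n 16 python train.py`, this keeps the per-GPU memory usage constant as GPUs are added, while the global batch size grows to 16 × 32 automatically.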