machine-learning, dask, distributed-computing, pytorch-lightning, ddp

What is the simplest way to train a pytorch-lightning model over a bunch of servers with Dask?


I have access to a couple dozen Dask servers without GPUs, but with complete control over the software (I can wipe them and install something different), and I want to accelerate PyTorch Lightning model training. What would be a possible way to integrate them, with as little additional code as possible?

I've researched this topic a bit and found the following possible options, but I can't determine which one to choose:

| # | Option | Info | Pro | Con |
|---|--------|------|-----|-----|
| 1 | dask-pytorch-ddp | Package for writing models with easier integration into Dask | Will likely work | Cannot use an existing model out of the box; the model itself needs rewriting |
| 2 | PL docs, on-prem cluster (intermediate) | Multiple copies of PyTorch Lightning on the network | Simplest way according to the Lightning docs | Fiddly to launch according to the docs |
| 3 | PL docs, SLURM cluster | Wipe/redeploy the cluster, set up SLURM | Less fiddly to launch individual jobs | Need to redeploy the cluster OS/software |
| 4 | PyTorch + Dask | Officially supported and documented use of Skorch | Has a package handling this (Skorch) | Will need to use plain PyTorch, not Lightning |

Are there any other options, or tutorials that cover this?


Solution

  • I'd recommend looking into Horovod for this. Horovod is a distributed deep-learning training framework for TensorFlow, Keras, PyTorch, and Apache MXNet. You can use the Horovod integration with PyTorch Lightning to distribute the training of your model. This approach requires installing Horovod on your servers, but only minimal changes to your existing PyTorch Lightning code (see the first sketch below).

    As an alternative, you can also consider using Ray for distributed training with PyTorch Lightning (see the second sketch below).
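
    A minimal sketch of the Horovod route, assuming a pytorch-lightning 1.x release that ships the built-in "horovod" strategy (the module MyLightningModule, the dataloader helper, and the host list are placeholders, not part of any real project):

```python
# Minimal sketch: reusing an existing LightningModule with Horovod.
# Assumes pytorch-lightning 1.x, where the built-in "horovod" strategy is
# available (older releases exposed it as accelerator="horovod"; Lightning 2.x
# dropped the integration, so check the docs for your version).
import pytorch_lightning as pl

from my_project import MyLightningModule, make_dataloader  # hypothetical names

model = MyLightningModule()

trainer = pl.Trainer(
    strategy="horovod",   # one training process per slot started by horovodrun
    accelerator="cpu",    # the cluster in the question has no GPUs
    max_epochs=10,
)
trainer.fit(model, make_dataloader())
```

    Each server then needs Horovod (with a Gloo or MPI backend) installed, and training is launched once from a head node, e.g. `horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python train.py`.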
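
    And a similarly hedged sketch for the Ray route, assuming the ray_lightning plugin (its API has changed between releases, e.g. RayPlugin vs. RayStrategy, so treat the exact names as an assumption and check the version you install):

```python
# Minimal sketch: distributing an existing LightningModule over a Ray cluster
# with the ray_lightning plugin. Exact class/argument names vary by release.
import pytorch_lightning as pl
import ray
from ray_lightning import RayStrategy

from my_project import MyLightningModule, make_dataloader  # hypothetical names

ray.init(address="auto")  # connect to a Ray cluster already running on the servers

trainer = pl.Trainer(
    strategy=RayStrategy(num_workers=8, use_gpu=False),  # 8 CPU-only workers
    max_epochs=10,
)
trainer.fit(MyLightningModule(), make_dataloader())
```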