The only way it would seem to work (logically) is if the model were loaded onto each of the GPUs. This would mean that when weights are updated, each GPU would need to update its copy as well, increasing the workload compared to a single GPU. Is this line of reasoning correct?
First of all, it is advised to use `torch.nn.parallel.DistributedDataParallel` instead.

You can check the `torch.nn.DataParallel` documentation, where the process is described (you can also check the source code and dig a little deeper on GitHub; here is how replication of the module is performed).
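For context, here is a minimal usage sketch of wrapping a module in `DataParallel` (the layer sizes, batch size and the assumption of two available GPUs are just placeholders for illustration):

```python
import torch
import torch.nn as nn

# Any regular module works here; this small MLP is only a placeholder.
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

# device_ids picks the GPUs; dim=0 means the batch dimension is split across them.
model = nn.DataParallel(model.to("cuda:0"), device_ids=[0, 1], dim=0)

x = torch.randn(32, 128, device="cuda:0")  # batch of 32 lives on the source GPU
out = model(x)                             # each GPU processes 16 samples
print(out.shape)                           # torch.Size([32, 10]), gathered on cuda:0
```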
Here is roughly how it's done:
All (or chosen) device ids are saved in the constructor, together with the dimension along which the data will be scattered (almost always `0`, meaning the input is split across devices along the batch dimension).
This is done during every `forward` run:

1. Inputs are scattered across the devices (tensors are split along the chosen dimension; `tuple`, `list` and `dict` are shallow copied, other data is shared among threads).
2. The module is replicated onto each device, so every GPU holds an up-to-date copy of the weights.
3. Each replica runs `module(*args, **kwargs)` on its chunk of the batch in a separate thread.
4. Outputs are gathered back onto the source device.
The source machine is `cuda:0` by default, but it can be chosen. Also, the weights are only updated on that single device; only the batch is scattered across the GPUs and the outputs are gathered back.
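You can see this in a training step: the optimizer only ever touches the single copy of the parameters on the source device. A hedged sketch (the sizes, learning rate and two-GPU setup are arbitrary assumptions):

```python
import torch
import torch.nn as nn

model = nn.Linear(128, 10).to("cuda:0")              # the weights live on cuda:0
dp_model = nn.DataParallel(model, device_ids=[0, 1])

# The optimizer sees the one parameter copy on cuda:0, not the per-GPU replicas.
optimizer = torch.optim.SGD(dp_model.parameters(), lr=0.1)

x = torch.randn(32, 128, device="cuda:0")
y = torch.randint(0, 10, (32,), device="cuda:0")

loss = nn.functional.cross_entropy(dp_model(x), y)   # forward scatters / gathers
loss.backward()                                      # gradients are reduced onto cuda:0
optimizer.step()                                     # weights updated once, on cuda:0

print({p.device for p in model.parameters()})        # {device(type='cuda', index=0)}
```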