
Does PyTorch's nn.DataParallel load the same model in each GPU?


The only way it would seem to work (logically) is if the model was loaded in each of the GPUs. This would mean that when weights are updated, each GPU would need to update the weights as well, increasing the workload compared to a single GPU. Is this line of reasoning correct?


Solution

  • First of all, it is advised to use torch.nn.parallel.DistributedDataParallel instead.
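    As a minimal single-process sketch (using the `gloo` backend so it also runs on CPU; the address and port values are placeholders, and real multi-GPU training would launch one process per GPU, e.g. via `torchrun`, each with its own rank):

    ```python
    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Single-process setup for illustration only; placeholder rendezvous values.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = nn.Linear(10, 2)   # toy model
    ddp_model = DDP(model)     # gradients are synchronized across processes
    out = ddp_model(torch.randn(4, 10))

    dist.destroy_process_group()
    ```

    Unlike `DataParallel`, each process keeps its own replica for the whole run, so the model is not re-copied on every forward pass.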

    You can check the torch.nn.DataParallel documentation, where the process is described (you can also check the source code and dig a little deeper on GitHub; here is how replication of the module is performed).

    Here is roughly how it's done:

    Initialization

    All (or the chosen) device ids are saved in the constructor, along with the dimension along which data will be scattered (almost always 0, meaning the input is split across devices along the batch dimension).
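    For example (a sketch; on a CPU-only machine the wrapper simply falls back to calling the wrapped module directly):

    ```python
    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)

    # By default all visible GPUs are used; dim=0 means the input batch
    # is scattered along its first (batch) dimension.
    dp = nn.DataParallel(model, dim=0)

    # A batch of 4 samples; with N GPUs each device would receive ~4/N samples.
    out = dp(torch.randn(4, 10))
    ```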

    Forward

    This is done during every forward run:

    1. Inputs are scattered (tensors are split along the scatter dimension; tuples, lists and dicts are shallow-copied; other data is shared among threads).
    2. If there is only one device, just return module(*args, **kwargs).
    3. If there are multiple devices, copy the network from the source device to the other devices (this is done on every forward pass!).
    4. Run forward on each device with its respective input.
    5. Gather the outputs from all devices onto the single source device (the outputs are concatenated).
    6. Run the rest of the code (backprop, weight updates, etc.) on the source device.

    The source device is cuda:0 by default, but it can be chosen. Weights are updated on a single device only; just the batch is scattered and the outputs gathered.
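    The steps above can be sketched in plain CPU code (a rough illustration only: copy.deepcopy stands in for the real CUDA replication, and the chunks run sequentially rather than in parallel threads):

    ```python
    import copy
    import torch
    import torch.nn as nn

    def data_parallel_forward_sketch(module, batch, num_devices):
        # Steps 1 and 3: scatter the batch along dim 0 and replicate the
        # module per "device" (real DataParallel re-replicates every forward).
        chunks = batch.chunk(num_devices, dim=0)
        replicas = [copy.deepcopy(module) for _ in chunks]
        # Step 4: run forward on each replica with its respective chunk.
        outputs = [replica(chunk) for replica, chunk in zip(replicas, chunks)]
        # Step 5: gather, i.e. concatenate the outputs back on the source device.
        return torch.cat(outputs, dim=0)

    module = nn.Linear(10, 2)
    batch = torch.randn(6, 10)
    # Every replica holds the same weights, so the result matches a plain forward pass.
    parallel_out = data_parallel_forward_sketch(module, batch, num_devices=3)
    ```

    This also makes the asker's intuition concrete: the replicas are throwaway copies, so gradients flow back to (and weights are updated on) the one source copy only.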