
pytorch - How to save and load a model from DistributedDataParallel training


I'm new to PyTorch's DistributedDataParallel(), but I found that most of the tutorials save the local rank 0 model during training. This means that if I have 3 machines with 4 GPUs each, I'll end up with 3 models, one saved on each machine.

For example, in the PyTorch ImageNet tutorial, on line 252:

if not args.multiprocessing_distributed or (args.multiprocessing_distributed
        and args.rank % ngpus_per_node == 0):
    save_checkpoint({...})

They save the model if rank % ngpus_per_node == 0.

To the best of my knowledge, DistributedDataParallel() automatically all-reduces the loss in the backend, so every process syncs the loss without any further work on my side. The models on the different processes should therefore differ only slightly at the end, which means saving just one model should be enough.

So why don't we just save the model on rank == 0 instead of rank % ngpus_per_node == 0?

And which model should I use if I get multiple models?

If this is the right way to save a model in distributed training, should I merge them, use one of them, or run inference with all three models and combine the results?

Please let me know if I'm wrong.


Solution

  • What is going on

    Please correct me if I'm wrong in any place

    The changes you are referring to were introduced in 2018 via this commit and described as:

    in multiprocessing mode, only one process will write the checkpoint

    Previously, checkpoints were saved without any if block, so every GPU on every node saved a model, which is indeed wasteful and would most probably overwrite the saved model multiple times on each node.

    Now, we are talking about multiprocessing distributed training (possibly many workers, each with possibly multiple GPUs).

    args.rank for each process is thus modified inside the script by this line:

    args.rank = args.rank * ngpus_per_node + gpu
    

    which has the following comment:

    For multiprocessing distributed training, rank needs to be the global rank among all the processes

    Hence args.rank is a unique ID among all GPUs on all nodes (or so it seems).

    If so, and each node has ngpus_per_node GPUs (from what I've gathered, this training code assumes every node has the same number of GPUs), then the model is saved by only one GPU on each node: the first one (local index 0), because args.rank % ngpus_per_node is exactly the local gpu index. In your example with 3 machines and 4 GPUs you would get 3 saved models (hopefully I understand this code correctly, as it's pretty convoluted tbh).

    If you used rank == 0, only one model per world (where the world is all n_gpus * n_nodes processes) would be saved.
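    To make the difference concrete, here is a small sketch in plain Python that mirrors the rank bookkeeping above. The helper names global_rank and saving_ranks are made up for illustration, and, like the ImageNet script, it assumes every node has the same number of GPUs.

```python
from typing import List


def global_rank(node_rank: int, gpu: int, ngpus_per_node: int) -> int:
    # Mirrors: args.rank = args.rank * ngpus_per_node + gpu
    return node_rank * ngpus_per_node + gpu


def saving_ranks(nnodes: int, ngpus_per_node: int, per_node: bool) -> List[int]:
    """Global ranks that would write a checkpoint under each rule (hypothetical helper)."""
    ranks = []
    for node_rank in range(nnodes):
        for gpu in range(ngpus_per_node):
            rank = global_rank(node_rank, gpu, ngpus_per_node)
            if per_node:
                # rank % ngpus_per_node == gpu, so this picks local GPU 0 on every node
                save = rank % ngpus_per_node == 0
            else:
                # only the very first process in the whole world
                save = rank == 0
            if save:
                ranks.append(rank)
    return ranks


# 3 machines with 4 GPUs each, as in the question
print(saving_ranks(3, 4, per_node=True))   # [0, 4, 8]
print(saving_ranks(3, 4, per_node=False))  # [0]
```

    So the ImageNet rule writes one checkpoint per node (global ranks 0, 4 and 8), while rank == 0 writes exactly one for the whole job.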

    Questions

    First question

    So why don't we just save the model on rank == 0 instead of rank % ngpus_per_node == 0?

    I will start with your assumption, namely:

    To the best of my knowledge, DistributedDataParallel() automatically all-reduces the loss in the backend, so every process syncs the loss without any further work on my side.

    To be precise, it has nothing to do with the loss; it is about averaging the gradients and applying the resulting corrections to the weights, as per the documentation:

    This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension. The module is replicated on each machine and each device, and each such replica handles a portion of the input. During the backwards pass, gradients from each node are averaged.

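    So, when the model is created with some weights, it is replicated on all devices (each GPU on each node). Each GPU then gets a part of the input (say, for a total batch size of 1024 and 4 nodes with 4 GPUs each, every GPU would get 1024 / 16 = 64 elements), calculates the forward pass and the loss, and performs backprop via the .backward() tensor method. During that backward pass the gradients are averaged across all processes with an all-reduce, so every process ends up with the same averaged gradients, and each process then applies the same optimizer step to its own replica. There is no root machine that optimizes the parameters and redistributes them; the module's state stays the same across all machines because every replica starts from the same weights and receives the same updates.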
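    A toy simulation of that synchronization in plain Python shows why the replicas never drift apart. It uses no torch and no real communication; the two-parameter "model", the gradient values and the helper names are invented for illustration. The replicas start from the same weights, every worker receives the same averaged gradient, and each applies the same SGD step locally.

```python
def all_reduce_mean(per_worker_grads):
    """Element-wise average over workers: what every worker holds after the all-reduce."""
    n = len(per_worker_grads)
    size = len(per_worker_grads[0])
    return [sum(g[i] for g in per_worker_grads) / n for i in range(size)]


def sgd_step(params, grads, lr):
    """Plain SGD update applied locally by one worker."""
    return [p - lr * g for p, g in zip(params, grads)]


world_size = 16                              # 4 nodes x 4 GPUs
global_batch = 1024
per_gpu_batch = global_batch // world_size   # 64 samples per GPU

initial = [0.5, -1.0]
replicas = [list(initial) for _ in range(world_size)]  # same starting weights everywhere

# Each worker computes different local gradients from its own shard of the batch
# (made-up numbers standing in for real backprop results).
local_grads = [[0.1 * (rank + 1), -0.05 * rank] for rank in range(world_size)]

avg = all_reduce_mean(local_grads)                       # identical on every worker
replicas = [sgd_step(p, avg, lr=0.1) for p in replicas]  # each worker steps on its own copy

print(per_gpu_batch)                            # 64
print(all(r == replicas[0] for r in replicas))  # True
```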

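    Note: DistributedDataParallel groups the gradients into buckets and launches an all-reduce over the whole process group for each bucket while the backward pass is still running (PyTorch's DDP design notes describe this). The communication backend (for example NCCL) then decides how the data travels between GPUs and nodes, and it can take advantage of the faster links inside a node. Either way, every process gets the same averaged result.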
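    Whatever route the backend takes, the result is the same average. A quick check in plain Python (the per-GPU gradients of a single scalar parameter are made-up values) compares a flat average over all 16 GPUs with averaging per node first and then across nodes; the two agree here because every node has the same number of GPUs.

```python
from statistics import fmean

# node index -> gradients of one scalar parameter reported by that node's 4 GPUs
grads_per_node = {
    0: [1.0, 2.0, 3.0, 4.0],
    1: [0.5, 0.5, 1.5, 1.5],
    2: [2.0, 2.0, 2.0, 2.0],
    3: [0.0, 1.0, 0.0, 1.0],
}

# One all-reduce over all 16 values at once
flat = fmean(g for node in grads_per_node.values() for g in node)

# Average inside each node first, then average the node results
hierarchical = fmean(fmean(node) for node in grads_per_node.values())

print(flat, hierarchical)  # 1.5 1.5
```

    With unequal GPU counts per node, a naive mean of per-node means would weight the GPUs differently, which is one more reason the ImageNet script assumes identical nodes.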

    Now, why would you save a model on each node in that case? In principle you could save only one (as all modules will be exactly the same), but that has some downsides:

    • Say the node where your model was saved crashes and the file is lost. You would have to redo everything. Saving a model is not a costly operation (done once per epoch or less), so it can easily be done on each node/worker.
    • When you restart training, the model would first have to be copied to each worker (along with any necessary metadata, though I don't think that's needed here). With a checkpoint on every node, each worker can simply load its local copy.
    • Nodes have to wait for every backward pass to finish anyway (so the gradients can be averaged), so if saving the model on a single node took a lot of time, the GPUs/CPUs on all the other nodes would sit idle waiting for it (unless some other synchronization scheme were applied; I don't think PyTorch provides one). Saving on every node in parallel is therefore essentially "no-cost" if you look at the overall picture.

    Question 2 (and 3)

    And which model should I use if I get multiple models?

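    It doesn't matter, as all of them will be exactly the same: the same optimizer updates are applied to models that start from the same initial weights (DDP broadcasts the parameters of rank 0 to all other processes when it is constructed). So there is nothing to merge or ensemble; just use any one of them.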

    You could use something along these lines to load your saved .pth model:

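    import torch
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # MyModelGoesHere stands for your own nn.Module class.
    # Wrapping it in DataParallel gives its state_dict keys the same "module." prefix
    # that a checkpoint saved from a DistributedDataParallel-wrapped model has.
    parallel_model = torch.nn.DataParallel(MyModelGoesHere())
    parallel_model.load_state_dict(
        torch.load("my_saved_model_state_dict.pth", map_location=device)
    )
    
    # DataParallel stores the wrapped model in its .module attribute
    usable_model = parallel_model.module.to(device)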
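    Alternatively, you can skip the DataParallel wrapper and strip the "module." prefix that DistributedDataParallel adds to every key of the saved state dict. This is a minimal sketch, assuming the checkpoint file holds the state_dict() of the DDP-wrapped model; strip_module_prefix is a hypothetical helper name, and the demo uses plain integers in place of tensors so it runs without torch.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from every key that has it."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Keys as they would look in a checkpoint saved from a DDP-wrapped model
# (integers stand in for the tensors).
ddp_state = OrderedDict([("module.fc.weight", 1), ("module.fc.bias", 2)])
print(list(strip_module_prefix(ddp_state)))  # ['fc.weight', 'fc.bias']

# Hypothetical usage with PyTorch (MyModelGoesHere is your own model class):
#   state = torch.load("my_saved_model_state_dict.pth", map_location="cpu")
#   model = MyModelGoesHere()
#   model.load_state_dict(strip_module_prefix(state))
```

    Another option is to save model.module.state_dict() instead of model.state_dict() in the first place, so the keys never get the prefix. Recent PyTorch versions also include torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which removes such a prefix in place.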