Tags: parallel-processing, pytorch, distributed-system

In the PyTorch Distributed Data Parallel (DDP) tutorial, how does `setup` know its rank?


Regarding the tutorial Getting Started with Distributed Data Parallel:

How does the setup() function know the rank when mp.spawn() doesn't pass it?

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    ...  # rest of the demo elided

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)

Solution

  • mp.spawn does pass the rank to the function it calls.

    From the torch.multiprocessing.spawn docs:

    torch.multiprocessing.spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn')

    ...

    • fn (function) -

      Function is called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned. This is a requirement imposed by multiprocessing. The function is called as fn(i, *args), where i is the process index and args is the passed through tuple of arguments.

    So when spawn invokes fn, it passes the process index as the first argument. That index is what demo_basic receives as rank; the args=(world_size,) tuple supplies only the remaining arguments.
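
You can see this mechanism in isolation, without any DDP machinery or GPUs. A minimal sketch (the function name fn and the printed message here are illustrative, not from the tutorial):

import torch.multiprocessing as mp

def fn(rank, world_size):
    # mp.spawn injects the process index as the first positional
    # argument; world_size comes from the args tuple.
    print(f"Hello from rank {rank} of {world_size}")

if __name__ == "__main__":
    world_size = 2
    # Each spawned process is called as fn(i, world_size),
    # with i = 0 and i = 1 respectively.
    mp.spawn(fn, args=(world_size,), nprocs=world_size, join=True)

Running this prints one line per process, each reporting a different rank, which is exactly how demo_basic in the tutorial receives its rank.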