Tags: python-3.x, pytorch, python-multiprocessing, python-logging

PyTorch: why does logging fail in DDP?


I would like to use logging in one of the processes managed by DistributedDataParallel (DDP). However, logging prints nothing in the following code (derived from this tutorial):

#!/usr/bin/python

import os, logging
# logging.basicConfig(level=logging.DEBUG)

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize the process group.
    dist.init_process_group('NCCL', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)

    if rank == 0:
        logger = logging.getLogger('train')
        logger.setLevel(logging.DEBUG)
        logger.info(f'Running DDP on rank={rank}.')

    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)  # optimizer takes DDP model.

    optimizer.zero_grad()
    inputs = torch.randn(20, 10)  # .to(rank)

    outputs = ddp_model(inputs)

    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()

    optimizer.step()

    cleanup()


def run_demo(demo_func, world_size):
    mp.spawn(
        demo_func,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )


def main():
    run_demo(demo_basic, 4)


if __name__ == "__main__":
    main()

However, when I uncomment the logging.basicConfig line, the logging works. May I know the reason, and how to fix this, please?


Solution

  • UPDATE

    Let's briefly review how loggers in the logging module work.

    Loggers are organized in a tree structure, i.e. every logger has a single parent logger. By default that parent is the root logger, and the root logger itself has no parent.
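
    For example, you can see this parent relationship in a plain Python session (a quick check using only the standard library, unrelated to the DDP code above):

    import logging

    train = logging.getLogger('train')
    root = logging.getLogger()       # no name: this is the root logger
    print(train.parent is root)      # True: 'train' hangs directly under the root
    print(root.parent)               # None: the root has no parent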

    When you call the Logger.info method on a logger (ignoring level checking here for simplicity), the logger iterates over all of its handlers and lets each of them handle the current record; a handler can be, for example, a StreamHandler that prints to the console or a FileHandler that writes to a file. Once all handlers of the current logger have finished, the record is passed to the parent logger, which handles it in the same way: it iterates over its own handlers, lets them handle the record, and finally hands the record on to the "grandparent". This procedure continues until it reaches the root of the logger tree, which has no parent.

    Check the implementation below or here:

    def callHandlers(self, record):
        c = self
        found = 0
        while c:
            for hdlr in c.handlers:
                found = found + 1
                if record.levelno >= hdlr.level:
                    hdlr.handle(record)
            if not c.propagate:
                c = None    #break out
            else:
                c = c.parent
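
    To make this concrete, here is a small standalone demonstration (plain logging, no DDP involved): the train logger has no handler of its own, yet its records are printed by the root logger's handler, and disabling propagation stops that.

    import logging

    logging.getLogger().addHandler(logging.StreamHandler())  # give the root logger a handler
    train = logging.getLogger('train')                       # 'train' itself has no handler
    train.setLevel(logging.DEBUG)

    train.info('printed by the root handler')  # the record propagates from 'train' up to the root
    train.propagate = False
    train.info('silently dropped')             # propagation is off and 'train' still has no handler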
    

    So in your case: you never attached any handler to the train logger. When you uncomment the logging.basicConfig(level=logging.DEBUG) line at the top of the file, a StreamHandler is created for the root logger. The train logger still has no handler of its own, but its parent, the root logger, now has a StreamHandler, and that handler is what actually prints everything you see. With that line commented out, not even the root logger has a StreamHandler, so nothing is printed at all. In fact, the issue has nothing to do with DDP.
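
    One possible fix (a minimal sketch, not the only option) consistent with the explanation above: either keep the logging.basicConfig call, or attach a handler to the train logger yourself inside the spawned worker, e.g.:

    def demo_basic(rank, world_size):
        setup(rank, world_size)

        logger = logging.getLogger('train')
        logger.setLevel(logging.DEBUG)
        if rank == 0 and not logger.handlers:  # attach a handler once, on rank 0 only
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s'))
            logger.addHandler(handler)
            logger.info(f'Running DDP on rank={rank}.')

        # ... rest of demo_basic unchanged ...

    Since mp.spawn starts fresh child processes, the handler has to be configured inside the worker function (or in module-level code that runs when the child re-imports the file), not only in the parent process.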

    By the way, the reason I couldn't reproduce your issue at first is that I use PyTorch 1.8, where logging.info is called during the execution of dist.init_process_group for backends other than MPI. That call implicitly invokes basicConfig, which creates a StreamHandler for the root logger, so the messages are printed as expected.
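
    If you want to check what your own PyTorch version does, one simple probe (hypothetical, just for inspection) is to look at the root logger's handlers around the init_process_group call inside setup:

    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        print('root handlers before init:', logging.getLogger().handlers)
        dist.init_process_group('NCCL', rank=rank, world_size=world_size)
        print('root handlers after init:', logging.getLogger().handlers)  # non-empty if a handler was added implicitly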

    ======================================================================

    One possible reason: during the execution of dist.init_process_group, _store_based_barrier is called, which finally calls logging.info (see the source code here). If the root logger has not been configured by the time dist.init_process_group runs, that implicit call initializes it in advance with the default settings, and a later logging.basicConfig call has no effect, so the root logger keeps ignoring log levels below WARNING.
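
    The mechanism behind this hypothesis can be reproduced with plain logging alone (a sketch; the first line stands in for the implicit logging.info call inside _store_based_barrier):

    import logging

    logging.info('stand-in for the implicit call')  # configures the root logger: adds a handler, level stays WARNING
    logging.basicConfig(level=logging.DEBUG)         # too late: ignored, the root logger already has a handler
    logging.getLogger().info('dropped')              # still filtered by the default WARNING level

    logging.basicConfig(level=logging.DEBUG, force=True)  # Python 3.8+: force reconfiguration
    logging.getLogger().info('now printed')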

    This is not the case in your code, though, because logging.basicConfig is at the top of the file and is executed before dist.init_process_group. Actually, I am able to run the code you provided (with the imports for nn, dist, etc. in place), and logging works normally. Maybe, while reducing the code to reproduce the issue, you unintentionally worked around the real problem? Could you double-check whether this resolves your issue?