I would like to use logging in one of the processes managed by DistributedDataParallel (DDP). However, logging prints nothing in the following code (derived from this tutorial):
#!/usr/bin/python
import os, logging
# logging.basicConfig(level=logging.DEBUG)
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Initialize the process group.
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank, world_size):
    setup(rank, world_size)
    if rank == 0:
        logger = logging.getLogger('train')
        logger.setLevel(logging.DEBUG)
        logger.info(f'Running DDP on rank={rank}.')

    # Create the model and move it to the GPU with id `rank`.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)  # The optimizer takes the DDP model's parameters.

    optimizer.zero_grad()
    inputs = torch.randn(20, 10)  # .to(rank)
    outputs = ddp_model(inputs)
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

def run_demo(demo_func, world_size):
    mp.spawn(
        demo_func,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )

def main():
    run_demo(demo_basic, 4)

if __name__ == "__main__":
    main()
However, when I uncomment the logging.basicConfig line near the top, logging works. May I know the reason, and how to fix this, please?
UPDATE
Let's briefly review how loggers in the logging module work.
Loggers are organized in a tree structure, i.e. every logger has a unique parent logger. By default, that parent is the root logger, and the root logger itself has no parent.
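You can inspect those parent pointers directly; here is a quick sketch (the logger names are arbitrary):
import logging

train = logging.getLogger('train')
print(train.parent)         # <RootLogger root (WARNING)> -- the default parent is the root logger
print(train.parent.parent)  # None -- the root logger has no parent

sub = logging.getLogger('train.loop')
print(sub.parent is train)  # True -- dotted names build the tree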
When you call the Logger.info method on a logger (ignoring level checking here for simplicity), the logger iterates over all of its handlers and lets them handle the current record (handlers can be, e.g., a StreamHandler, which prints to stderr by default, or a FileHandler, which writes to a file). After all of the handlers of the current logger finish their jobs, the record is given to its parent logger, which handles the record in the same way: it iterates over all of its own handlers, lets them handle the record, and finally passes the record on to the "grandparent". This procedure continues until the root of the logger tree is reached, which has no parent.
Check the implementation below or here:
def callHandlers(self, record):
    c = self
    found = 0
    while c:
        for hdlr in c.handlers:
            found = found + 1
            if record.levelno >= hdlr.level:
                hdlr.handle(record)
        if not c.propagate:
            c = None  # break out
        else:
            c = c.parent
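To see this propagation in action, here is a small sketch (the handler setup and names are just for illustration):
import logging, sys

# Attach a handler only to the root logger.
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

train = logging.getLogger('train')
train.setLevel(logging.DEBUG)
train.info('hello')   # printed: 'train' has no handlers of its own, but the record
                      # propagates to the root logger's StreamHandler

train.propagate = False
train.info('silent')  # not printed: propagation stops at 'train', which has no handlers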
So in your case, you didn't specify any handler for the train logger. When you uncomment the logging.basicConfig(level=logging.DEBUG) line, a StreamHandler is created for the root logger. Although there still isn't any handler on the train logger, there is now a StreamHandler on its parent, the root logger, and that is what actually prints everything you see; the train logger itself prints nothing. With that line commented out, not even the root logger has a StreamHandler, so nothing is printed at all. In fact, the issue has nothing to do with DDP.
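Here is a minimal sketch of both the failure and two possible fixes, with no DDP involved at all:
import logging

logger = logging.getLogger('train')
logger.setLevel(logging.DEBUG)
logger.info('dropped')  # neither 'train' nor the root logger has a handler; the record
                        # only reaches the last-resort handler, whose level is WARNING

# Fix 1: configure the root logger -- this is what uncommenting basicConfig does.
logging.basicConfig(level=logging.DEBUG)
logger.info('printed')  # propagates to the root handler: "INFO:train:printed"

# Fix 2 (independent of the root logger): give 'train' its own handler instead.
# logger.addHandler(logging.StreamHandler())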
By the way, the reason I couldn't reproduce your issue at first is that I use PyTorch 1.8, where logging.info is called during the execution of dist.init_process_group for backends other than MPI. This implicitly calls basicConfig, which creates a StreamHandler for the root logger, so messages appear to be printed as expected.
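You can reproduce that implicit initialization without PyTorch at all; in this sketch, the module-level logging.info call stands in for the one made inside PyTorch:
import logging

logging.info('from the library')  # not printed: the root logger has no handlers yet,
                                  # so this call runs basicConfig() with the default
                                  # WARNING level before logging the record

train = logging.getLogger('train')
train.setLevel(logging.DEBUG)
train.info('now visible')  # printed: the root logger now has a StreamHandler, and
                           # only handler levels are checked during propagation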
======================================================================
One possible reason: during the execution of dist.init_process_group, it calls _store_based_barrier, which finally calls logging.info (see the source code here). So unless you call logging.basicConfig before you call dist.init_process_group, logging will already have been initialized by that implicit call, with the root logger left at its default WARNING level, so it ignores the lower levels of log.
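That hypothesized ordering problem can be sketched without PyTorch (here logging.info stands in for the call made inside _store_based_barrier):
import logging

logging.info('simulates init_process_group')  # implicit basicConfig(): the root logger
                                              # keeps its default WARNING level

logging.basicConfig(level=logging.DEBUG)  # too late: the root logger already has a
                                          # handler, so this call is a no-op
logging.debug('still ignored')            # not printed: DEBUG < WARNING at the root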
This is not the case in your code, though, because logging.basicConfig is at the top of the file and is therefore executed before dist.init_process_group. Actually, I'm able to run the code you provided (after filling in the missing imports such as nn and dist), and logging works normally. Maybe you attempted to reduce the code to reproduce the issue, but unintentionally circumvented the real issue along the way? Could you double-check whether this resolves your issue?