distributed torch data collision from all_gather (writing all_gather results to file "fixes" the problem)


Problem:

  • a distributed process computes errors and returns them alongside float indices
  • when the errors from the separate ranks are gathered, the indices collide
    • e.g. if the dataset has 100 samples and there are 4 GPUs, the set of gathered indices has length 25 instead of the expected 100
  • when I write each rank's data (pre-gather) to file, I can verify that the indices are 100% disjoint across ranks
  • when I write each rank's data (post-gather) to file, the issue disappears
  • if I comment out the post-gather debug file writing, the issue returns
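
To make the pre-gather/post-gather comparison in the bullets above mechanical, the index sets can be checked directly. This is a minimal, framework-agnostic sketch; `diagnose_indices` is a hypothetical helper (not a torch API) that takes plain Python lists, e.g. from `data[:, -1].tolist()` on each rank before the gather and on the concatenated result after it.

```python
from itertools import chain, combinations


def diagnose_indices(per_rank, gathered, expected_total):
    """Compare pre-gather index lists with the post-gather index list.

    per_rank: one list of indices per rank, collected before all_gather.
    gathered: flat list of indices collected after all_gather + cat.
    expected_total: number of samples in the dataset (e.g. 100).
    """
    rank_sets = [set(indices) for indices in per_rank]
    # Pre-gather: no two ranks should share an index.
    disjoint = all(a.isdisjoint(b) for a, b in combinations(rank_sets, 2))
    unique_after = len(set(gathered))
    return {
        "pre_gather_disjoint": disjoint,
        "pre_gather_total": sum(len(s) for s in rank_sets),
        "post_gather_unique": unique_after,
        # Fewer distinct indices than samples means the gather lost data.
        "collision": unique_after < expected_total,
    }


# Illustrative input: 100 samples split evenly over 4 ranks.
per_rank = [list(range(r * 25, (r + 1) * 25)) for r in range(4)]

healthy = diagnose_indices(per_rank, list(chain.from_iterable(per_rank)), 100)
# A gathered list with the reported symptom: only 25 distinct indices.
broken = diagnose_indices(per_rank, per_rank[0] * 4, 100)
print(healthy)
print(broken)
```

A collision then shows up as `post_gather_unique` dropping to 25 while `pre_gather_disjoint` stays True, which matches the symptom described.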

NOTE: printing out the post-gather results also "fixes" the issue, but sorting the post-gather results does not.

So something about writing the post-gather data to file resolves some distributed shenanigans. I'm reminded of the need to flush streams to avoid unexpected results, but I don't see any analogous requirement in the documentation.

Here's a minimal example that shows what's going on in my code:

import logging

import torch
import torch.distributed

logger = logging.getLogger(__name__)

# setup_distributed_stuff()  # assumes the process group is already initialized
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

# Data returned from the distributed computation: one row per sample, with the
# error in column 0 and the (float) sample index in the last column.
# Note that there's no overlap in indices between the different ranks.
indices = torch.arange(
    rank * 100 // world_size,
    (rank + 1) * 100 // world_size,
    dtype=torch.float32,
)
errors = torch.rand_like(indices)  # stand-in for the real errors (illustrative)
data = torch.stack([errors, indices], dim=1)

# `data` is confirmed to be disjoint across ranks by writing to file here.

# Gather data from all ranks (all_gather requires the same shape on every rank).
if world_size > 1:
    all_data = [torch.zeros_like(data) for _ in range(world_size)]
    torch.distributed.all_gather(all_data, data)
    data = torch.cat(all_data, dim=0)

    # By writing `data` to file for debugging, the problem goes away...
    #     i.e. len(set(data[:, -1].tolist())) == 100
    # If I comment this out, then my gathered data collides...
    #     i.e. len(set(data[:, -1].tolist())) == 100 // world_size
    with open("debug_data.pt", "wb") as _file:
        torch.save(data, _file)

    # I can also simply print the indices and get the same effect...
    logger.info(
        "Gathered result indices: {}...{}".format(
            data[:10, -1], data[-10:, -1]
        )
    )

    # However, sorting by index doesn't do me any good...
    data = data[data[:, -1].argsort(dim=0)]

if rank == 0:
    pass  # do_something(data)

Solution:

  • Adding a torch.distributed.barrier() call after the all_gather() call fixes the issue in a more satisfying manner. I didn't think to do this because the docs state that all_gather() is a blocking call. Perhaps they mean "blocking" in the sense of not async (async_op=False), which is distinct from synchronizing all ranks.

    I suppose logging and writing the results to file "fix" the issue while sorting does not because the former are not torch ops (and therefore not managed by the distributed process group), so they force synchronization.
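
A minimal, self-contained sketch of where the barrier goes, assuming a CPU-only `gloo` process group launched with `torchrun` (the question's GPU setup would more likely use `nccl`, and this CPU version is not expected to reproduce the collision). `rank_slice`, `gather_with_barrier` and the file name `gather_demo.py` are hypothetical. As a side note, if the gathered tensors live on the GPU, printing them or passing them to `torch.save` involves copying the data to host memory, which has to wait for pending GPU work, whereas `argsort` and indexing generally stay on the device without a host round trip. That would be consistent with the synchronization explanation above.

```python
import os

import torch
import torch.distributed as dist

NUM_SAMPLES = 100  # dataset size used in the question


def rank_slice(rank, world_size, n=NUM_SAMPLES):
    """Float indices owned by `rank`, split the same way as the question's example.

    all_gather needs equal-sized tensors on every rank, so n should be
    divisible by world_size.
    """
    return torch.arange(
        rank * n // world_size, (rank + 1) * n // world_size, dtype=torch.float32
    )


def gather_with_barrier(data):
    """all_gather followed by an explicit barrier, as in the solution above."""
    all_data = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
    dist.all_gather(all_data, data)
    # Extra synchronization point across all ranks before the result is used.
    dist.barrier()
    return torch.cat(all_data, dim=0)


def main():
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, which the
    # default env:// initialization reads.
    dist.init_process_group(backend="gloo")
    try:
        rank, world_size = dist.get_rank(), dist.get_world_size()
        gathered = gather_with_barrier(rank_slice(rank, world_size))
        if rank == 0:
            print("distinct indices:", len(set(gathered.tolist())))
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    if "RANK" in os.environ:
        main()
    else:
        print("Launch with: torchrun --nproc_per_node=4 gather_demo.py")
```

With 4 processes, rank 0 should report 100 distinct indices. Checking for `RANK` simply keeps a plain `python gather_demo.py` run from failing on missing launcher environment variables.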