python, pytorch, multiprocessing, concatenation, distributed

PyTorch Python Distributed Multiprocessing: Gather/Concatenate tensor arrays of different lengths/sizes


If you have tensor arrays of different lengths across several GPU ranks, the default torch.distributed.all_gather method does not work, because it requires every rank to contribute a tensor of the same size.

For example, if you have:

if gpu == 0:
    q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
else:
    q = torch.tensor([5.3], device=torch.device(gpu))

and you need to gather these two tensor arrays as follows:

all_q = [torch.tensor([1.5, 2.3]), torch.tensor([5.3])]

the default torch.distributed.all_gather does not work because the lengths (2 and 1) are different.


Solution

  • Since the built-in methods cannot gather tensors of different lengths directly, we need to write a custom function with the following steps:

    1. Use dist.all_gather to get sizes of all arrays.
    2. Find the max size.
    3. Pad local array to max size using zeros/constants.
    4. Use dist.all_gather to get all padded arrays.
    5. Unpad the added zeros/constants using sizes found in step 1.

    The function below does this for 1-D tensors:

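    import torch
    import torch.distributed as dist

    def all_gather(q, ws, device):
        """
        Gathers 1-D tensor arrays of different lengths across multiple gpus
        
        Parameters
        ----------
            q : tensor array (1-D)
            ws : world size
            device : current gpu device
            
        Returns
        -------
            all_q : list of gathered tensor arrays from all the gpus
    
        """
        # Step 1: share every rank's length (sizes are equal-shaped, so all_gather works)
        local_size = torch.tensor([q.size(0)], device=device)
        all_sizes = [torch.zeros_like(local_size) for _ in range(ws)]
        dist.all_gather(all_sizes, local_size)
        # Convert to plain ints so they can be compared and used for slicing
        all_sizes = [int(size.item()) for size in all_sizes]
        # Step 2: find the max size
        max_size = max(all_sizes)
    
        # Step 3: pad the local tensor with zeros up to max_size
        size_diff = max_size - q.size(0)
        if size_diff:
            padding = torch.zeros(size_diff, device=device, dtype=q.dtype)
            q = torch.cat((q, padding))
    
        # Step 4: gather the padded tensors, which now all have the same length
        all_qs_padded = [torch.zeros_like(q) for _ in range(ws)]
        dist.all_gather(all_qs_padded, q)
        # Step 5: trim each gathered tensor back to its original length
        all_qs = []
        for q_padded, size in zip(all_qs_padded, all_sizes):
            all_qs.append(q_padded[:size])
        return all_qs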
    
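The padding and trimming logic can be checked on a single process, without initializing torch.distributed. The sketch below is only an illustration: `simulate_uneven_gather` is a hypothetical helper (not part of PyTorch), and it stands in for the two `dist.all_gather` calls by looping over a plain Python list whose entries play the role of the ranks.

```python
import torch


def simulate_uneven_gather(per_rank):
    """Mimic the five steps on one process; per_rank[i] stands in for rank i's 1-D tensor."""
    # Steps 1-2: collect every length and take the maximum.
    sizes = [t.size(0) for t in per_rank]
    max_size = max(sizes)

    # Step 3: pad each tensor with zeros up to max_size.
    padded = []
    for t in per_rank:
        pad = torch.zeros(max_size - t.size(0), dtype=t.dtype)
        padded.append(torch.cat((t, pad)))

    # Step 4: after a real all_gather, every rank would hold these equal-sized rows.
    gathered = torch.stack(padded)

    # Step 5: trim each row back to its original length.
    return [row[:n] for row, n in zip(gathered, sizes)]


all_q = simulate_uneven_gather([torch.tensor([1.5, 2.3]), torch.tensor([5.3])])
print([t.numel() for t in all_q])  # [2, 1]
```

The zeros added in step 3 never reach the caller, because the lengths recorded in step 1 decide where each row is cut.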

    Once the tensors are gathered, we can use torch.cat to concatenate them into a single array if needed:

    all_q = all_gather(q, ws, device)
    torch.cat(all_q)
    # torch.tensor([1.5, 2.3, 5.3])
    
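If the per-rank boundaries are needed again after concatenation, the lengths of the gathered tensors are enough to recover them. A minimal sketch using `torch.split`, assuming the gathered list is already available on one process (the values below are placeholders that mirror the example in the question):

```python
import torch

# Placeholder for the list returned by the gather function.
all_q = [torch.tensor([1.5, 2.3]), torch.tensor([5.3])]
sizes = [t.size(0) for t in all_q]

flat = torch.cat(all_q)              # one 1-D tensor with 3 elements
per_rank = torch.split(flat, sizes)  # tuple of chunks with lengths 2 and 1

print([t.numel() for t in per_rank])  # [2, 1]
```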

    Adapted from: github