python, pytorch, tensor

What is an efficient way to merge a list of dictionaries with the same keys, whose values are tensors? [PyTorch]


Is there a more efficient way than this?

from collections import defaultdict
import torch

mention_inputs = defaultdict(list)

for idx in mention_indices:
    mention_input, _ = ...
    for key, value in mention_input.items():  # value is a tensor of shape (dim,)
        mention_inputs[key].append(value)

mention_inputs = {key: torch.stack(value) for key, value in mention_inputs.items()}

Solution

  • I think this is pretty much what the collate_fn of the DataLoader does: collect the per-sample dictionaries into a list and pass that list to default_collate:

    from torch.utils.data.dataloader import default_collate

    # mention_input_list: the list of per-sample dictionaries gathered in the loop above
    # (pass the dictionaries themselves, not the indices)
    mention_inputs = default_collate(mention_input_list)
    

    For a longer explanation: in PyTorch, the Dataset may return a dictionary for each sample. For simplicity, say each sample is a dictionary with two keys, each holding a tensor of dimension D as its value. To perform efficient batching, the DataLoader first samples many of these dictionaries, obtaining a list of B dictionaries. The collate function is then in charge of converting that list of dictionaries into a single dictionary. In this example, the resulting dictionary would have 2 keys, each holding a tensor of shape (B, D).
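
    As a quick illustration (a minimal sketch with made-up keys and dimensions), collating a list of B = 3 dictionaries whose values are tensors of shape (D,) = (4,) yields a single dictionary of tensors of shape (B, D):

    import torch
    from torch.utils.data.dataloader import default_collate

    # three toy samples, each a dict of two (4,)-tensors; the keys are hypothetical
    samples = [
        {"input_ids": torch.randn(4), "attention_mask": torch.randn(4)}
        for _ in range(3)
    ]

    batch = default_collate(samples)
    print(batch["input_ids"].shape)       # torch.Size([3, 4])
    print(batch["attention_mask"].shape)  # torch.Size([3, 4])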