
How to convert a tuple of dictionaries of PyTorch tensors into a dictionary of tensors?


I have a tuple of dictionaries that hold PyTorch tensors:

tuple_of_dicts_of_tensors = (
    {'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
    {'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
    {'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
)

I would like to transform it into a dictionary of tensors:

dict_of_tensors = {
    'key_1': torch.tensor([[1,1,1], [2,2,2], [3,3,3]]),
    'key_2': torch.tensor([[4,4,4], [5,5,5], [6,6,6]])
}

How would you recommend doing that, and what is the most efficient way? The tensors are on a GPU device, so I would like to use as few Python for loops as possible.

Thanks!


Solution

  • You can use PyTorch's built-in default_collate() function:

    import torch
    from torch.utils.data import default_collate
    
    tuple_of_dicts_of_tensors = (
        {'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
        {'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
        {'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
    )
    
    dict_of_tensors = default_collate(tuple_of_dicts_of_tensors)
    print(dict_of_tensors)
    
    # >>> {'key_1': tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]),
    #      'key_2': tensor([[4, 4, 4], [5, 5, 5], [6, 6, 6]])}
    

    It is quite a powerful function, although its documentation is not immediately clear. Here is a shortened quote from the documentation:

    Function that takes in a batch of data and puts the elements within the batch into a tensor with an additional outer dimension - batch size.

    Here is the general input type (based on the type of the element within the batch) to output type mapping:

    • torch.Tensor -> torch.Tensor (with an added outer dimension batch size)
    • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]
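
    The following deliberately simplified sketch makes the two quoted rules concrete by implementing only them. mini_collate is a hypothetical name chosen for illustration and is not part of PyTorch. The real default_collate() also handles numbers, strings, sequences, NumPy arrays and more, and it takes extra steps (such as allocating shared memory) when it is called inside DataLoader worker processes.

```python
from collections.abc import Mapping

import torch


def mini_collate(batch):
    """Hypothetical, simplified collate covering only the two quoted rules."""
    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        # Rule 1: Tensor -> Tensor with an added outer (batch) dimension.
        return torch.stack(list(batch), dim=0)
    if isinstance(elem, Mapping):
        # Rule 2: Mapping[K, V_i] -> Mapping[K, collate([V_1, V_2, ...])].
        # Keys are taken from the first element, so every mapping is
        # assumed to have the same keys.
        return {key: mini_collate([d[key] for d in batch]) for key in elem}
    raise TypeError(f"mini_collate does not handle {type(elem).__name__}")


tuple_of_dicts_of_tensors = (
    {'key_1': torch.tensor([1, 1, 1]), 'key_2': torch.tensor([4, 4, 4])},
    {'key_1': torch.tensor([2, 2, 2]), 'key_2': torch.tensor([5, 5, 5])},
    {'key_1': torch.tensor([3, 3, 3]), 'key_2': torch.tensor([6, 6, 6])},
)

result = mini_collate(tuple_of_dicts_of_tensors)
print(result["key_1"].shape)  # torch.Size([3, 3])
```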

    In your case, the batch is your tuple, and its elements are mappings (your dicts) of tensors. So,

    • in a first step the mappings get "moved to the outside" – meaning you end up with one dict (second bullet from the quoted doc);
    • in a second step, the function is again applied to all the values of your dicts, which are tensors – meaning, for each key, the tensors are collated together into one tensor with a new batch dimension (first bullet from the quoted doc).
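
    The two steps can also be written out by hand. The sketch below assumes that every dict has the same keys as the first one. For this particular input it does roughly what default_collate() does outside DataLoader workers. It uses the GPU when one is available, so you can check that the result stays on the tensors' device. The only Python loops run over keys and batch elements. The actual copying happens in a single torch.stack() call per key.

```python
import torch

# Use the GPU if one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tuple_of_dicts_of_tensors = (
    {'key_1': torch.tensor([1, 1, 1], device=device),
     'key_2': torch.tensor([4, 4, 4], device=device)},
    {'key_1': torch.tensor([2, 2, 2], device=device),
     'key_2': torch.tensor([5, 5, 5], device=device)},
    {'key_1': torch.tensor([3, 3, 3], device=device),
     'key_2': torch.tensor([6, 6, 6], device=device)},
)

# Step 1: move the mapping to the outside, giving one dict that maps each
# key to the list of that key's tensors across the batch.
per_key = {
    key: [d[key] for d in tuple_of_dicts_of_tensors]
    for key in tuple_of_dicts_of_tensors[0]
}

# Step 2: collate each list into a single tensor with a new leading batch
# dimension; for plain tensors this is a torch.stack along dim 0.
dict_of_tensors = {key: torch.stack(tensors, dim=0) for key, tensors in per_key.items()}

print(dict_of_tensors["key_2"].shape)  # torch.Size([3, 3])
```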

    In other words, you can think of default_collate() as moving the batch dimension inward: a batch of A objects that contain B objects (in your case, a tuple of dicts containing tensors) becomes a single A object that contains batches of B objects (in your case, a dict of batches of tensors, where each "batch of tensors" is again a single tensor with a new, prepended batch dimension).
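
    The following sketch shows the "batch dimension moves inward" view in action. It assumes a PyTorch version that exposes default_collate under torch.utils.data, as the import in the answer does. First, it checks that position i along the new leading dimension of each value is the tensor from the i-th dict. Then it moves the batch dimension back out with torch.unbind(). That reverse direction is only an illustration and is not a PyTorch utility.

```python
import torch
from torch.utils.data import default_collate

tuple_of_dicts_of_tensors = (
    {'key_1': torch.tensor([1, 1, 1]), 'key_2': torch.tensor([4, 4, 4])},
    {'key_1': torch.tensor([2, 2, 2]), 'key_2': torch.tensor([5, 5, 5])},
    {'key_1': torch.tensor([3, 3, 3]), 'key_2': torch.tensor([6, 6, 6])},
)
dict_of_tensors = default_collate(tuple_of_dicts_of_tensors)

# The batch dimension now lives inside each value: for every key, index i
# along dim 0 gives back the tensor that was stored in the i-th dict.
for i, d in enumerate(tuple_of_dicts_of_tensors):
    for key, tensor in d.items():
        assert torch.equal(dict_of_tensors[key][i], tensor)

# Moving the batch dimension back out: unbind each value along dim 0 and
# zip the pieces back into one dict per batch element.
keys = list(dict_of_tensors)
restored = tuple(
    dict(zip(keys, pieces))
    for pieces in zip(*(torch.unbind(dict_of_tensors[k], dim=0) for k in keys))
)
print(len(restored), restored[1]["key_2"])  # 3 tensor([5, 5, 5])
```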