Is there a more efficient way than this?
from collections import defaultdict
import torch

mention_inputs = defaultdict(list)
for idx in mention_indices:
    mention_input, _ = ...
    for key, value in mention_input.items():  # value is a tensor of shape (dim,)
        mention_inputs[key].append(value)
# stack each key's list of (dim,) tensors into a single (B, dim) tensor
mention_inputs = {key: torch.stack(value) for key, value in mention_inputs.items()}
I think this is pretty much what the collate_fn
of the DataLoader does:
from torch.utils.data.dataloader import default_collate
mention_inputs = default_collate(mention_input_list)  # the list of per-sample dictionaries collected in the loop above
For a longer explanation: in PyTorch, the Dataset might return a dictionary for each sample. For simplicity, say each dictionary has two keys, each with a tensor of dimension D as value. To perform efficient batching, the DataLoader first samples many of these dictionaries, obtaining a list of B dictionaries. The collate function is then in charge of converting this list of dictionaries into a single dictionary. In this example, the resulting dictionary would have 2 keys, each with a tensor of shape (B, D) as value.
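A minimal sketch of this behavior, using a hypothetical batch of B=3 samples, each a dictionary with two keys holding tensors of dimension D=4:

```python
import torch
from torch.utils.data.dataloader import default_collate

# A list of B=3 sample dictionaries, each value a tensor of shape (D,) with D=4
samples = [{"a": torch.randn(4), "b": torch.randn(4)} for _ in range(3)]

# default_collate turns the list of dicts into a single dict of stacked tensors
batch = default_collate(samples)
print(batch["a"].shape)  # torch.Size([3, 4])
print(batch["b"].shape)  # torch.Size([3, 4])
```

Note that default_collate also recurses into nested structures (dicts of dicts, lists of tensors, etc.), so it handles more complex sample layouts than the hand-rolled loop above.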