Tags: python, deep-learning, pytorch, batch-processing, dataloader

How to create batches using PyTorch DataLoader such that each example in a given batch has the same value for an attribute?


Suppose I have a list, datalist, which contains several examples (of type torch_geometric.data.Data for my use case). Each example has an attribute num_nodes.

For demo purposes, such a datalist can be created using the following snippet of code:

import torch
from torch_geometric.data import Data # each example is of this type
import networkx as nx # for creating random data
import numpy as np
# the python list containing the examples
datalist = []
# build 1024 random graphs for each of two node counts (9 and 11)
for num_node in [9, 11]:
    for _ in range(1024):
        edge_index = torch.from_numpy(
            np.array(nx.fast_gnp_random_graph(num_node, 0.5).edges())
        ).t().contiguous()
        datalist.append(
            Data(
                x=torch.rand(num_node, 5), 
                edge_index=edge_index, 
                edge_attr=torch.rand(edge_index.size(1))
            )
        )

From the above datalist object, I can create a torch_geometric.loader.DataLoader (which subclasses torch.utils.data.DataLoader) naively (without any constraints) by using the DataLoader constructor as:

from torch_geometric.loader import DataLoader
dataloader = DataLoader(
    datalist, batch_size=128, shuffle=True
)

My question is, how can I use the DataLoader class to ensure that each example in a given batch has the same value for the num_nodes attribute?

PS: I tried to solve it and came up with a hacky solution that combines multiple DataLoader objects using the combine_iterators function snippet from here, as follows:

import random

def get_combined_iterator(*iterables):
    # draw the next batch from a randomly chosen underlying iterator,
    # dropping each iterator once it is exhausted
    nexts = [iter(iterable).__next__ for iterable in iterables]
    while nexts:
        next_fn = random.choice(nexts)
        try:
            yield next_fn()
        except StopIteration:
            nexts.remove(next_fn)

from collections import defaultdict

# split the examples into separate lists keyed by num_nodes
datalists = defaultdict(list)
for data in datalist:
    datalists[data.num_nodes].append(data)
dataloaders = (
    DataLoader(data, batch_size=128, shuffle=True) for data in datalists.values()
)
batches = get_combined_iterator(*dataloaders)

But I think there must be a more elegant/better way of doing this, hence this question.


Solution

  • If your underlying dataset is map-style, you can define a torch.utils.data.Sampler that returns the indices of the examples you want to batch together. Pass an instance of it as the batch_sampler kwarg to your DataLoader and remove the batch_size kwarg, since the sampler forms the batches for you depending on how you implement it. A minimal sketch of such a batch sampler is shown below.
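As an illustration of this idea, here is a minimal sketch of a batch sampler that groups dataset indices by their num_nodes value. The class name SameNumNodesBatchSampler and its constructor arguments are illustrative (not part of the original answer), and it assumes the dataset is the datalist built above and that torch_geometric.loader.DataLoader forwards extra kwargs such as batch_sampler to the underlying torch.utils.data.DataLoader.

import random
from torch.utils.data import Sampler
from torch_geometric.loader import DataLoader

class SameNumNodesBatchSampler(Sampler):
    """Yield batches of indices such that every example in a batch has the same num_nodes."""
    def __init__(self, datalist, batch_size, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # group dataset indices by the num_nodes value of each example
        self.groups = {}
        for idx, data in enumerate(datalist):
            self.groups.setdefault(data.num_nodes, []).append(idx)

    def __iter__(self):
        batches = []
        for indices in self.groups.values():
            if self.shuffle:
                random.shuffle(indices)
            # split each same-num_nodes group into chunks of at most batch_size
            for i in range(0, len(indices), self.batch_size):
                batches.append(indices[i:i + self.batch_size])
        if self.shuffle:
            # interleave batches coming from different groups
            random.shuffle(batches)
        yield from batches

    def __len__(self):
        return sum(
            (len(indices) + self.batch_size - 1) // self.batch_size
            for indices in self.groups.values()
        )

# usage: the batch sampler decides both grouping and ordering,
# so batch_size and shuffle are not passed to the DataLoader itself
batch_sampler = SameNumNodesBatchSampler(datalist, batch_size=128, shuffle=True)
dataloader = DataLoader(datalist, batch_sampler=batch_sampler)

for batch in dataloader:
    # every graph in this batch has the same num_nodes
    pass

Note that batch_sampler is mutually exclusive with the batch_size, shuffle, sampler and drop_last arguments of torch.utils.data.DataLoader, so those must be left at their defaults when this approach is used.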