Tags: pytorch, dataset, multiple-databases, dataloader

How to load data from multiple datasets in PyTorch


I have two datasets of images - indoors and outdoors - and they don't have the same number of examples.

Each dataset has images that contain a certain number of classes (minimum 1, maximum 4). These classes can appear in both datasets, and each class has 4 categories - red, blue, green, white. Example: Indoor - cats, dogs, horses; Outdoor - dogs, humans.

I am trying to train a model where I tell it, "here is an image that contains a cat, tell me its color", regardless of where it was taken (indoors, outdoors, in a car, on the moon).

To do that, I need to present my model with examples so that every batch has only one category (cat, dog, horse, or human), but I want to sample from all datasets (two in this case) that contain these objects and mix them. How can I do this?

It has to take into account that the number of examples in each dataset is different, and that some categories appear in only one dataset while others appear in more than one. And each batch must contain only one category.

I would appreciate any help; I have been trying to solve this for a few days now.


Solution

  • Assuming the question is:

    1. Combine 2+ data sets with potentially overlapping categories of objects (distinguishable by label)
    2. Each object has 4 "subcategories", one for each color (distinguishable by label)
    3. Each batch should only contain a single object category

    The first step is to make sure the object labels from the two data sets are consistent, if they are not already. For example, if the dog class is label 0 in the first data set but label 2 in the second, we need to make sure the two dog categories are correctly merged. We can do this "translation" with a simple data set wrapper:

    from torch.utils.data import Dataset

    class TranslatedDataset(Dataset):
      """
      Args:
        dataset: The original dataset.
        translate_label: A function (e.g. a lambda) that maps the original
          dataset's label to the label it should have in the combined data set
      """
      def __init__(self, dataset, translate_label):
        super().__init__()
        self._dataset = dataset
        self._translate_label = translate_label
    
      def __len__(self):
        return len(self._dataset)
    
      def __getitem__(self, idx):
        inputs, target = self._dataset[idx]
        return inputs, self._translate_label(target)
    

    The next step is to combine the translated data sets, which can be done easily with a ConcatDataset:

    from torch.utils.data import ConcatDataset

    first_original_dataset = ...
    second_original_dataset = ...
    
    first_translated = TranslatedDataset(
      first_original_dataset, 
      lambda y: 0 if y == 2 else 2 if y == 0 else y, # or similar
    )
    second_translated = TranslatedDataset(
      second_original_dataset, 
      lambda y: y, # or similar
    )
    
    combined = ConcatDataset([first_translated, second_translated])
    

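    At this point it can be worth sanity-checking that the labels really did merge. This quick check is my addition (not part of the original answer) and assumes the targets are plain integer labels:

    from collections import Counter

    # Count examples per translated label across the combined data set.
    # Note: this iterates the whole data set once, so it may be slow.
    label_counts = Counter(combined[i][1] for i in range(len(combined)))
    print(label_counts)
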
    Finally, we need to restrict batch sampling to the same class, which is possible with a custom Sampler when creating the data loader.

    import torch

    class SingleClassSampler(torch.utils.data.Sampler):
      def __init__(self, dataset, batch_size):
        super().__init__(dataset)
        # We need to create sequential groups
        # with batch_size elements from the same class
        indices_for_target = {} # dict to store a list of indices for each target
        
        # note: this iterates the whole dataset once, loading every example
        for i, (_, target) in enumerate(dataset):
          # converting to string since Tensors hash by reference, not value
          str_targ = str(target)
          if str_targ not in indices_for_target:
            indices_for_target[str_targ] = []
          indices_for_target[str_targ] += [i]
    
        # keep only a whole number of batches for each class
        # (slicing up to len(v) - len(v) % batch_size avoids dropping
        # everything when len(v) is already a multiple of batch_size)
        trimmed = { 
          k: v[:len(v) - (len(v) % batch_size)]
          for k, v in indices_for_target.items()
        }
    
        # concatenate the lists of indices for each class
        self._indices = sum(trimmed.values(), [])
      
      def __len__(self):
        return len(self._indices)
    
      def __iter__(self):
        yield from self._indices
    
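    One caveat with the sampler above is that it always yields the classes in the same fixed order. If you also want the batch order shuffled between epochs while keeping each batch single-class, here is a possible variant - my own sketch, not part of the original answer - that groups indices into per-class batches and shuffles those batches:

    import random
    from torch.utils.data import Sampler

    class ShuffledSingleClassSampler(Sampler):
      def __init__(self, dataset, batch_size):
        super().__init__(dataset)
        indices_for_target = {}
        for i, (_, target) in enumerate(dataset):
          indices_for_target.setdefault(str(target), []).append(i)

        # chop each class's indices into full, single-class batches
        self._batches = []
        for indices in indices_for_target.values():
          random.shuffle(indices)
          for start in range(0, len(indices) - batch_size + 1, batch_size):
            self._batches.append(indices[start:start + batch_size])

      def __len__(self):
        return sum(len(batch) for batch in self._batches)

      def __iter__(self):
        # shuffle the order of the batches, not their contents,
        # so every batch still contains a single class
        random.shuffle(self._batches)
        for batch in self._batches:
          yield from batch

    Either sampler plugs into the DataLoader the same way.
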

    Then to use the sampler:

    from torch.utils.data import DataLoader

    loader = DataLoader(
      combined, 
      sampler=SingleClassSampler(combined, 64), 
      batch_size=64,
      # note: shuffle must not be set together with a custom sampler;
      # the sampler alone decides the sample order
    )
    
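    As a quick check that the loader behaves as intended (this snippet is my addition, and assumes the targets collate into a 1-D tensor of integer labels), you can assert that every batch holds exactly one class:

    for _, targets in loader:
      # each batch should contain exactly one object class
      assert len(set(targets.tolist())) == 1
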

    I haven't run this code, so it might not be exactly right, but hopefully it will put you on the right track.


    torch.utils.data Docs