Tags: python, pytorch

Return multiple images from a custom dataset


I have a set of bags of frames obtained from YouTube videos, and I would like to return an entire bag when I iterate over my dataset. My custom dataset class is the following:

dataset_path = Path('/content/VideoClassificationDataset')

class VideoDataset(Dataset):
  def __init__(self, dictionary, transform = None):
    self.l_dict = list(dictionary.items())
    self.transform = transform
  
  def __len__(self):
    return len(self.l_dict)

  def __get_item__(self, index):
    item = self.l_dict[index]

    images_path = item[0]
    images = [Image.open(f'{dataset_path}/{images_path}/{image}') for image in os.listdir(f'{dataset_path}/{images_path}')]
    
    y_labels = torch.tensor(item[1])

    if self.transform:
      for image in images: self.transform(image)
    
    return images, y_labels

Moreover, I've also done the following:

def split_train(train_data, perc_val_size):
  train_size = len(train_data)
  val_size = int((train_size * perc_val_size) // 100)
  train_size -= val_size

  return random_split(train_data, [int(train_size), int(val_size)])

train_data, val_data = split_train(VideoDataset(train_dict, transform=train_transform()), 20)
test_data = VideoDataset(dictionary=test_dict, transform=test_transform())


BATCH_SIZE = 16
NUM_WORKERS = os.cpu_count()

def generate_dataloaders(train_data, val_data, test_data, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):

  train_dl = DataLoader(dataset=train_data,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        shuffle=True)

  val_dl = DataLoader(dataset=val_data,
                      batch_size=batch_size,
                      num_workers=num_workers,
                      shuffle=True)

  test_dl = DataLoader(dataset=test_data,
                       batch_size=batch_size,
                       num_workers=num_workers,
                       shuffle=False)  # no need to shuffle the test data for a time-series dataset

  return train_dl, val_dl, test_dl

train_dl, val_dl, test_dl = generate_dataloaders(train_data, val_data, test_data)

The train_dict and test_dict are dictionaries that contain the path of each bag of shots as the key and the list of labels as the value, like so:

{'train/iqGq-8vHEJs/bag_of_shots0': [2],
 'train/iqGq-8vHEJs/bag_of_shots1': [2],
 'train/gnw83R8R6jU/bag_of_shots0': [119],
 'train/gnw83R8R6jU/bag_of_shots1': [119],
...
}

The problem is that when I try to inspect what the dataloaders contain:

train_features_batch, train_labels_batch = next(iter(train_dl))
print(train_features_batch.shape, train_labels_batch.shape)

val_features_batch, val_labels_batch = next(iter(val_dl))
print(val_features_batch.shape, val_labels_batch.shape)

I get:

NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataset.py", line 295, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataset.py", line 53, in __getitem__
    raise NotImplementedError
NotImplementedError

At this point I'm not sure whether I can return a list of images from my __get_item__() method at all.


Solution

  • There is a typo in the method name: it should be __getitem__, not __get_item__.

    Since __getitem__ is not defined in your custom dataset class, the method from the base class (torch.utils.data.Dataset) is used instead. The base class deliberately leaves it unimplemented, because every dataset that inherits from it is expected to provide its own implementation, so you get a NotImplementedError.

    More documentation on this can be found in the torch.utils.data.Dataset documentation.
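
    For reference, here is a minimal sketch of the corrected class with the method renamed. As a side note, it also assigns the result of self.transform back to each image, which the original loop discarded:

class VideoDataset(Dataset):
  def __init__(self, dictionary, transform=None):
    self.l_dict = list(dictionary.items())
    self.transform = transform

  def __len__(self):
    return len(self.l_dict)

  def __getitem__(self, index):  # renamed from __get_item__
    images_path, labels = self.l_dict[index]
    bag_dir = f'{dataset_path}/{images_path}'
    images = [Image.open(f'{bag_dir}/{image}') for image in os.listdir(bag_dir)]

    y_labels = torch.tensor(labels)

    if self.transform:
      # keep the transformed frames (the original loop threw the results away)
      images = [self.transform(image) for image in images]

    return images, y_labels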