Tags: python, pytorch, dataset, dataloader

KeyError when iterating through a PyTorch DataLoader


I am trying to build a model with PyTorch, and I want to use a customized dataset. So I have a dataset.py that defines a class, MyDataset, which is a subclass of torch.utils.data.Dataset. Here's the file.

# dataset.py
import torch
from tqdm import tqdm
import numpy as np
import re
from torch.utils.data import Dataset
from pathlib import Path


class MyDataset(Dataset):
    def __init__(self, path, size=10000):
        if not Path(path).exists():
            raise FileNotFoundError(path)
            
        self.data = []
        self.load_data(path, size)


    def __len__(self):
        return len(self.data)


    def __getitem__(self, index):
        return self.data[index]


    def load_data(self, path, size):
        # Load data from CSV files and do some preparation.
        # Each sample has the form (int_tag1, int_tag2, feature_dictionary)
        # and is appended to self.data.
        pass

Then I tried to test this dataset using a DataLoader in the test file dataset_test.py:

from torch.utils.data import DataLoader
from dataset import MyDataset


path = 'dataset/sample_train.csv'
size = 1000
dataset = MyDataset(path, size)

dataloader = DataLoader(dataset, batch_size=1000)
for v in dataloader:
    print(v)

I got the following output:

730600it [11:08, 1093.11it/s]
1000it [00:00, 20325.47it/s]
Traceback (most recent call last):
  File "dataset_test.py", line 12, in <module>
    for v in dataloader:
  File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in <listcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
KeyError: '210'

The first two lines are probably progress output from loading the data. (I'm not sure, because I didn't print anything myself; but I am using tqdm while loading, so I assume it's tqdm's output?)

Then I got this KeyError. Which part should I modify? I think the dataset class itself is well-written, since there is no error while reading the data from the file. Is the format of my samples wrong, so that the DataLoader cannot load data from the dataset properly? Is there any requirement on the sample format? I've read other people's code, but I didn't find anything saying that samples in a Dataset class must have a particular format.

EDIT: A single sample looks like this

('0', '0', {'210': '9093445', '216': '9154780', '301': '9351665', '205': '4186222', '206': '8316799', '207': '8416205', '508': '9355039', '121': '3438658', '122': '3438762', '101': '31390', '124': '3438769', '125': '3438774', '127': '3438782', '128': '3864885', '129': '3864887', '150_14': '3941161', '127_14': '3812616', '109_14': '449068', '110_14': '569621'})

The first two '0's are labels, and the following dictionary contains features.
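
A quick way to check whether all samples really share the same feature keys (a minimal sketch, using the dataset object built in dataset_test.py above):

# Minimal sketch: compare every sample's feature keys to the first sample's.
reference_keys = set(dataset[0][2])
for i in range(len(dataset)):
    mismatch = set(dataset[i][2]) ^ reference_keys
    if mismatch:
        print(f"sample {i} differs from sample 0 on keys: {mismatch}")
        break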


Solution

  • As @Shai mentioned, if the keys in feature_dictionary are not the same for every sample in a batch, you get this error from the default collate_fn of DataLoader: as the traceback shows, it batches dictionaries key by key using the keys of the first sample, so any sample missing one of those keys raises a KeyError. As a solution, you can write a custom collate_fn like the following, and it works:

    class MyDataset(Dataset):
        # ... your code ...
    
        def collate_fn(self, batch):
            # `batch` is a list of (tag1, tag2, feature_dict) samples.
            # Keep each field as a plain Python list instead of letting
            # default_collate try to merge the feature dicts key by key.
            tag1_batch = []
            tag2_batch = []
            feat_dict_batch = []
            for tag1, tag2, feat_dict in batch:
                tag1_batch.append(tag1)
                tag2_batch.append(tag2)
                feat_dict_batch.append(feat_dict)
            
            return tag1_batch, tag2_batch, feat_dict_batch
    
    path = 'dataset/sample_train.csv'
    size = 1000
    dataset = MyDataset(path, size)
    
    # Pass the custom collate_fn so the DataLoader no longer uses default_collate.
    dataloader = DataLoader(dataset, batch_size=3, collate_fn=dataset.collate_fn)
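
  • For reference, the KeyError is easy to reproduce with default_collate alone. The sketch below uses hypothetical data with mismatched keys (note that a plain Python list of samples also works as a map-style dataset):

    from torch.utils.data import DataLoader

    # Two samples whose feature dicts do NOT share the same keys.
    samples = [
        ('0', '0', {'210': '9093445', '216': '9154780'}),
        ('0', '1', {'210': '9093445', '301': '9351665'}),  # no '216' key
    ]

    # default_collate batches dicts using the keys of the first sample,
    # so looking up '216' in the second sample's dict raises KeyError: '216'.
    broken_loader = DataLoader(samples, batch_size=2)
    # next(iter(broken_loader))  # -> KeyError: '216'

  • With the custom collate_fn in place, each batch instead comes back as three plain Python lists (the tag1 values, the tag2 values, and the feature dicts), which you can convert to tensors however your model expects.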