Tags: python, image, machine-learning, pytorch, computer-vision

Skip bad data points when loading data using DataLoader


I am trying to perform an image classification task using the mini-ImageNet dataset. The dataset I want to use contains a few bad data points (I am not sure why). I would like to load this data and train my model on it, skipping the bad data points entirely. How do I do this? The dataset class I am using is as follows:

import os
import os.path as osp

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class MiniImageNet(Dataset):

    def __init__(self, root, train=True,
                 transform=None,
                 index_path=None, index=None, base_sess=None):
        if train:
            setname = 'train'
        else:
            setname = 'test'
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.train = train  # training set or test set
        self.IMAGE_PATH = osp.join(root, 'miniimagenet/images')
        self.SPLIT_PATH = osp.join(root, 'miniimagenet/split')

        # Read the split CSV, skipping the header row.
        csv_path = osp.join(self.SPLIT_PATH, setname + '.csv')
        with open(csv_path, 'r') as f:
            lines = [x.strip() for x in f.readlines()][1:]

        self.data = []
        self.targets = []
        self.data2label = {}
        lb = -1

        self.wnids = []

        # Assign a new integer label each time an unseen WordNet id appears.
        for l in lines:
            name, wnid = l.split(',')
            path = osp.join(self.IMAGE_PATH, name)
            if wnid not in self.wnids:
                self.wnids.append(wnid)
                lb += 1
            self.data.append(path)
            self.targets.append(lb)
            self.data2label[path] = lb

        self.y = self.targets

        # Note: the transform passed in above is overwritten by the
        # defaults below.
        if train:
            image_size = 84
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])

        else:
            image_size = 84
            self.transform = transforms.Compose([
                transforms.Resize([image_size, image_size]),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, targets = self.data[i], self.targets[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, targets

I tried wrapping the load in a try-except block, but then __getitem__ returns None instead of skipping the sample, and the DataLoader raises an error when it tries to batch the None. How do I completely skip a data point in a DataLoader?


Solution

  • Try removing the bad data points at the end of the __init__ function, so the DataLoader never sees them.

    # Iterate in reverse so deleting items does not shift the
    # indices of entries that have not been checked yet.
    for i in range(len(self.data) - 1, -1, -1):
        if is_bad_data(self.data[i], self.targets[i]):
            del self.data[i]
            del self.targets[i]
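
    Here is_bad_data is a placeholder for whatever check identifies your bad samples. A minimal sketch, assuming "bad" means an image file that PIL cannot decode (adjust the check to whatever "bad" means in your data):

    def is_bad_data(path, target):
        # Hypothetical check: treat a sample as bad if PIL cannot
        # decode the image file at this path.
        try:
            with Image.open(path) as img:
                img.verify()  # raises on truncated/corrupt files
            return False
        except Exception:
            return True

  • Alternatively, if a bad point can only be detected at load time, keep the try-except in __getitem__, return None for bad samples, and drop the Nones in a custom collate_fn so they never reach a batch. A sketch (skip_none_collate is not part of the original code):

    from torch.utils.data import DataLoader
    from torch.utils.data.dataloader import default_collate

    def skip_none_collate(batch):
        # Drop samples for which __getitem__ returned None, then
        # batch the remainder as usual.
        batch = [s for s in batch if s is not None]
        return default_collate(batch)

    loader = DataLoader(dataset, batch_size=32, collate_fn=skip_none_collate)

    Note that batches filtered this way can come out smaller than batch_size, and a batch consisting entirely of bad samples would be empty, so the first approach is preferable when the bad points can be identified up front.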