pytorch conv-neural-network dataloader

How to load one type of image in CIFAR10 or STL10 with PyTorch


This is a very simple question: I'm just trying to select a specific class of images (e.g. "car") from a standard PyTorch image dataset. At the moment the data loader looks like this:

import torch
import torchvision

def cycle(iterable):
    # Loop over the iterable forever, restarting whenever it is exhausted
    while True:
        for x in iterable:
            yield x

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.STL10(
        'drive/My Drive/training/stl10',
        split='train+unlabeled',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])),
    shuffle=True, batch_size=8)
train_iterator = iter(cycle(train_loader))
class_names = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']

The iterator returns batches of shuffled images of all classes, but I would like to choose which classes are returned, e.g. only images of deer or ships.


Solution

  • Solved it by collecting the indices of the wanted class and wrapping the dataset in a Subset:

    import torch
    import torchvision

    def cycle(iterable):
        while True:
            for x in iterable:
                yield x

    # Return the indices of all images of a certain class (e.g. airplanes = class 0)
    def get_same_index(target, label):
        label_indices = []
        for i in range(len(target)):
            if target[i] == label:
                label_indices.append(i)
        return label_indices

    # STL10 dataset
    train_dataset = torchvision.datasets.STL10('drive/My Drive/training/stl10', split='train+unlabeled', download=True, transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor()]))

    batch_size = 8   # same value as in the question's loader
    label_class = 1  # birds

    # Get indices of label_class
    train_indices = get_same_index(train_dataset.labels, label_class)

    # Subset exposes only the selected indices of the full dataset
    bird_set = torch.utils.data.Subset(train_dataset, train_indices)

    train_loader = torch.utils.data.DataLoader(dataset=bird_set, shuffle=True,
                                               batch_size=batch_size, drop_last=True)
    train_iterator = iter(cycle(train_loader))
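
The question also asks about picking classes by name ("deer", "ships") and about CIFAR10. The index selection above could be generalised along these lines. This is only a sketch: `indices_for_classes` is a hypothetical helper, not part of torch or torchvision, and the toy label list stands in for the real label array so the snippet runs without downloading anything.

```python
# Hypothetical helper (not a torchvision function): pick indices by class name.
def indices_for_classes(labels, class_names, wanted):
    """Return the positions in `labels` whose class name is listed in `wanted`."""
    wanted_ids = {class_names.index(name) for name in wanted}
    return [i for i, label in enumerate(labels) if int(label) in wanted_ids]


class_names = ['airplane', 'bird', 'car', 'cat', 'deer',
               'dog', 'horse', 'monkey', 'ship', 'truck']

# Toy labels standing in for train_dataset.labels.
# With split='train+unlabeled', unlabeled STL10 images carry the label -1,
# so they never match a real class and are filtered out automatically.
toy_labels = [4, -1, 8, 1, 4, -1, 8, 0]

deer_and_ship = indices_for_classes(toy_labels, class_names, ['deer', 'ship'])
print(deer_and_ship)  # [0, 2, 4, 6]
```

The resulting list can be passed to `torch.utils.data.Subset` exactly like `train_indices` above. For CIFAR10 the labels are stored in `train_dataset.targets` rather than `train_dataset.labels`, and the class list would need to be CIFAR10's own.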