This is a very simple question: I'm trying to select a specific class of images (e.g. "car") from a standard PyTorch image dataset. At the moment the data loader looks like this:
def cycle(iterable):
    """Yield items from *iterable* forever, re-iterating it on every pass.

    Unlike ``itertools.cycle``, nothing is cached. Each pass iterates
    *iterable* again, so a ``DataLoader`` with ``shuffle=True`` produces a new
    order every epoch instead of replaying the first one.

    If a complete pass yields nothing (an empty iterable, or a one-shot
    iterator that is already exhausted), the generator stops. It does not
    spin forever in a loop that never yields.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            # Guard against a busy loop: an empty pass ends the stream.
            return
# Loader over the whole STL-10 'train+unlabeled' split, 8 images per batch.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.STL10(
        'drive/My Drive/training/stl10',
        split='train+unlabeled',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ]),
    ),
    shuffle=True,
    batch_size=8,
)
# Endless stream of batches. Each pass re-iterates train_loader, so
# shuffle=True gives a new order every pass. (The original created this
# iterator twice; the second copy was redundant and has been removed.)
train_iterator = iter(cycle(train_loader))

# STL-10 class names. An integer label indexes into this list (2 -> 'car').
# NOTE(review): torchvision appears to give unlabeled samples the label -1;
# verify against the installed version.
class_names = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']
The iterator returns batches of shuffled images from all classes, but I would like to choose which classes are returned, e.g. only images of deer, or only images of ships.
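Done it! Find the indices of the samples whose label matches the wanted class, then wrap the dataset in a `Subset` restricted to those indices: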
def cycle(iterable):
    """Yield items from *iterable* forever, re-iterating it on every pass.

    Unlike ``itertools.cycle``, nothing is cached. Each pass iterates
    *iterable* again, so a ``DataLoader`` with ``shuffle=True`` produces a new
    order every epoch instead of replaying the first one.

    If a complete pass yields nothing (an empty iterable, or a one-shot
    iterator that is already exhausted), the generator stops. It does not
    spin forever in a loop that never yields.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            # Guard against a busy loop: an empty pass ends the stream.
            return
def get_same_index(target, label):
    """Return the positions in *target* whose value equals *label*.

    Use it to select every sample of one class (e.g. aeroplanes = class 0)
    from a dataset's label sequence. The returned indices can be passed to
    ``torch.utils.data.Subset``.

    Args:
        target: Iterable of per-sample labels (e.g. a list or a NumPy array).
        label: The label value to match.

    Returns:
        A list of integer indices in ascending order. It is empty if nothing
        matches.
    """
    return [i for i, value in enumerate(target) if value == label]
# Class to keep, as an index into the STL-10 class list: 1 -> 'bird'.
label_class = 1

# STL-10 dataset ('train+unlabeled' split); images become tensors on access.
# NOTE(review): torchvision appears to label unlabeled samples -1, so they
# can never match a non-negative label_class. split='train' may be enough
# here; confirm before changing.
train_dataset = torchvision.datasets.STL10(
    root='drive/My Drive/training/stl10',
    split='train+unlabeled',
    download=True,
    transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()]
    ),
)

# Keep only the samples whose label equals label_class.
train_indices = get_same_index(train_dataset.labels, label_class)
bird_set = torch.utils.data.Subset(train_dataset, train_indices)

# NOTE(review): batch_size is not defined in this snippet and must be set
# before this point.
train_loader = torch.utils.data.DataLoader(
    bird_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # discard a final incomplete batch
)
train_iterator = iter(cycle(train_loader))