
PyTorch: how can I use a picture as the label in a DataLoader?


I want to do image reconstruction with autoencoders in PyTorch, but I couldn't find a way to use an image as the label for an input image (the label image is different from the original one).

I've tried the ImageFolder approach, but I think that's meant for classification, and I haven't been able to come up with a solution. Should I create a custom dataset for this?

Thanks in advance!


Solution

  • Write a custom Dataset that returns an (input image, label image) pair for each index. Below is a simple example.

    
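    from torch.utils.data import Dataset, DataLoader
    
    class CustomDataset(Dataset):
        """Pairs each input image with its target (label) image."""
    
        def __init__(self, input_imgs, label_imgs, transform=None):
            # input_imgs[i] is the network input, label_imgs[i] is its target image
            assert len(input_imgs) == len(label_imgs), "need one label image per input"
            self.input_imgs = input_imgs
            self.label_imgs = label_imgs
            self.transform = transform
    
        def __len__(self):
            return len(self.input_imgs)
    
        def __getitem__(self, idx):
            input_img, label_img = self.input_imgs[idx], self.label_imgs[idx]
            if self.transform is not None:
                # The transform is called once per image, so random transforms
                # would not be applied identically to the input and the label.
                input_img = self.transform(input_img)
                label_img = self.transform(label_img)
            return input_img, label_img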
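One caveat with calling the same transform separately on each image: random transforms (for example a random flip or crop) draw new randomness on every call, so the input and its label can end up transformed differently. Below is a minimal sketch of keeping a pair in sync, assuming the images are already tensors shaped (C, H, W); PairedFlipDataset is a hypothetical name, and the flip is hand-written with torch.flip rather than taken from any torchvision API.

```python
import random

import torch
from torch.utils.data import Dataset


class PairedFlipDataset(Dataset):
    """Hypothetical example: input/label pairs sharing one random horizontal flip.

    Assumes every image is a tensor shaped (C, H, W).
    """

    def __init__(self, input_imgs, label_imgs, p=0.5):
        assert len(input_imgs) == len(label_imgs), "need one label image per input"
        self.input_imgs = input_imgs
        self.label_imgs = label_imgs
        self.p = p  # probability of flipping a pair

    def __len__(self):
        return len(self.input_imgs)

    def __getitem__(self, idx):
        input_img, label_img = self.input_imgs[idx], self.label_imgs[idx]
        # Draw the random decision once and apply it to both images,
        # so the label stays spatially aligned with the input.
        if random.random() < self.p:
            input_img = torch.flip(input_img, dims=[-1])  # flip along width
            label_img = torch.flip(label_img, dims=[-1])
        return input_img, label_img
```

The same idea extends to other random augmentations: sample the parameters once per item, then apply them to both images.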
    
    
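Since the question mentions ImageFolder: one common way to get input/label pairs from disk is to keep the inputs and the targets in two separate folders, with matching filenames. The sketch below assumes exactly that layout. PairedFolderDataset and pil_loader are hypothetical names, Pillow is assumed for reading images, and files present in only one of the two folders are skipped. A transform such as torchvision.transforms.ToTensor() would typically be passed in to turn the PIL images into tensors.

```python
import os

from torch.utils.data import Dataset


def pil_loader(path):
    # Assumes Pillow is installed; imported lazily so other loaders can be used without it.
    from PIL import Image
    with Image.open(path) as img:
        return img.convert("RGB")


class PairedFolderDataset(Dataset):
    """Hypothetical example: pairs files with the same name in input_dir and label_dir."""

    def __init__(self, input_dir, label_dir, transform=None, loader=pil_loader):
        self.input_dir = input_dir
        self.label_dir = label_dir
        # Keep only filenames present in both folders, sorted for a stable order.
        self.names = sorted(set(os.listdir(input_dir)) & set(os.listdir(label_dir)))
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        input_img = self.loader(os.path.join(self.input_dir, name))
        label_img = self.loader(os.path.join(self.label_dir, name))
        if self.transform is not None:
            input_img = self.transform(input_img)
            label_img = self.transform(label_img)
        return input_img, label_img
```

Matching by filename is just one convention; if your pairs are listed in a CSV or follow a naming rule, only the construction of self.names and the two paths in __getitem__ need to change.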

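    Then pass an instance of CustomDataset to a DataLoader. Each iteration then yields a batch of input images together with the matching label images:

    dataset = CustomDataset(input_imgs, label_imgs, transform)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)  # batch_size is just an example value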
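To make the "image as label" part concrete, here is a minimal training-loop sketch. It assumes random tensors in place of real images and uses torch.utils.data.TensorDataset instead of CustomDataset so that the snippet runs on its own (both yield (input, label) pairs). The tiny autoencoder, the 1x28x28 image size, the batch size, and the learning rate are illustrative choices only.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in data: 64 input images and 64 *different* target images.
input_imgs = torch.rand(64, 1, 28, 28)
label_imgs = torch.rand(64, 1, 28, 28)

dataset = TensorDataset(input_imgs, label_imgs)  # item i -> (input_imgs[i], label_imgs[i])
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# A deliberately tiny convolutional autoencoder (example architecture only).
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),           # encoder: 28x28 -> 14x14
    nn.ReLU(),
    nn.ConvTranspose2d(8, 1, kernel_size=4, stride=2, padding=1),  # decoder: 14x14 -> 28x28
    nn.Sigmoid(),
)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(2):
    for inputs, targets in dataloader:
        outputs = model(inputs)
        # The reconstruction is compared with the label image, not with the input.
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: last batch loss {loss.item():.4f}")
```

The only difference from a plain autoencoder is the loss target: with an ordinary autoencoder you would call criterion(outputs, inputs), whereas here the second element of each pair supplies the target.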