Tags: python, computer-vision, pytorch, mnist, dcgan

How to create a PyTorch Dataset from .pt files?


I have transformed MNIST images and saved them as .pt files in a folder on Google Drive. I'm writing my PyTorch code in Colab.

I would like to use these files to create a Dataset that stores the images as tensors. How can I do this?

Transforming the images during training took too long, so I transformed them in advance and saved them all as .pt files. I just want to load them back as a dataset and use them in my model.


Solution

  • Saving the pre-transformed images is a good idea. In that case, you can write your own Dataset class to load them.

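    import glob
    import os

    import torch
    from torch.utils.data import Dataset, DataLoader
    from torch.utils.data.sampler import RandomSampler

    class ReaderDataset(Dataset):
        def __init__(self, folder):
            # collect the paths of all saved .pt files
            # (assumes each file holds one transformed image tensor)
            self.paths = sorted(glob.glob(os.path.join(folder, "*.pt")))

        def __len__(self):
            # total dataset size = number of saved files
            return len(self.paths)

        def __getitem__(self, index):
            # load one saved tensor; the DataLoader groups these into batches
            return torch.load(self.paths[index])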
    

    Then you can create a DataLoader as follows:

    train_dataset = ReaderDataset(filepath)
    train_sampler = RandomSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=batchify,
        pin_memory=args.cuda,
        drop_last=args.parallel
    )
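    # filepath is the location of the saved .pt files
    # args holds the parameters as attributes (e.g. an argparse.Namespace);
    # a plain dict would need args["batch_size"] instead of args.batch_size
    # batchify is a custom function that prepares each mini-batch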
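To check the whole pipeline end to end, here is a minimal self-contained sketch. It assumes each .pt file holds a single 1x28x28 image tensor written with `torch.save`, and it includes a folder-based `ReaderDataset` so the block runs on its own. The fake data, the `SimpleNamespace` standing in for `args`, and the stacking `batchify` are hypothetical placeholders, not the answerer's actual code.

```python
import glob
import os
import tempfile
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import RandomSampler


class ReaderDataset(Dataset):
    """Loads one pre-transformed image tensor per .pt file in `folder`."""

    def __init__(self, folder):
        # Sorted so the index -> file mapping is stable across runs.
        self.paths = sorted(glob.glob(os.path.join(folder, "*.pt")))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Files are read lazily, one per sample, so RAM use stays small.
        return torch.load(self.paths[index])


def batchify(samples):
    # Hypothetical collate function: stack the per-file 1x28x28 tensors
    # into a single (batch, 1, 28, 28) tensor.
    return torch.stack(samples)


# Hypothetical stand-in for `args`: parameters are read as attributes.
args = SimpleNamespace(
    batch_size=4,
    data_workers=0,
    cuda=torch.cuda.is_available(),
    parallel=False,
)

# Write ten fake "transformed MNIST" tensors so the example needs no real data.
filepath = tempfile.mkdtemp()
for i in range(10):
    torch.save(torch.rand(1, 28, 28), os.path.join(filepath, f"img_{i:05d}.pt"))

train_dataset = ReaderDataset(filepath)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=RandomSampler(train_dataset),
    num_workers=args.data_workers,
    collate_fn=batchify,
    pin_memory=args.cuda,
    drop_last=args.parallel,
)

batch_shapes = [tuple(batch.shape) for batch in train_loader]
print(batch_shapes)
```

With 10 files and `batch_size=4`, the loader yields batches of 4, 4 and 2 images; `drop_last=True` would discard the final partial batch. For plain tensors the default `collate_fn` already stacks samples the same way, so a custom `batchify` is only needed when each mini-batch requires extra processing.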