
Torchvision.transforms implementation of Flatten()


I have grayscale images, but I need to transform them into a dataset of 1D vectors. How can I do this? I could not find a suitable method in transforms:

import torch
import torchvision
import torchvision.transforms as transforms

# ImageFolder has no train argument, so the split is assumed to live in separate folders
train_dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=transforms.ToTensor())
test_dataset = torchvision.datasets.ImageFolder(root='./data/test', transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=4, shuffle=False)

Solution

  • Here's how you can do it using T.Lambda:

    import torch
    from torchvision.datasets import MNIST
    import torchvision.transforms as T
    
    # without flatten
    dataset = MNIST(root='.', download=True, transform=T.ToTensor())
    print(dataset[0][0].shape)
    # >>> torch.Size([1, 28, 28])
    
    # with flatten (using Lambda; torch.flatten also collapses the channel dim, so 1 x 28 x 28 -> 784)
    dataset_flatten = MNIST(
        root='.',
        download=True,
        transform=T.Compose([T.ToTensor(), T.Lambda(lambda x: torch.flatten(x))]),
    )
    print(dataset_flatten[0][0].shape)
    # >>> torch.Size([784])