neural-network pytorch dataloader

PyTorch transformation on MNIST dataset


I am working on a weak supervision project where I need to apply a "mask" to a dataset, and I don't know exactly how to do it. Let me explain further with some code.

I am using the MNIST dataset, which I need to edit so that a square in the middle of each image is cut out (set to zero). The code below does this for a single image using nested for loops.

for i in range(int(image_size/2-5),int(image_size/2+3)):
   for j in range(int(image_size/2-5),int(image_size/2+3)):
      image[i][j] = 0
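
To make the masked region concrete, here is a small worked example of the index arithmetic, assuming image_size is 28 (the MNIST image size; the value is not stated in the snippet above). It uses plain Python only, and the names start, stop and masked_indices are introduced purely for illustration.

```python
image_size = 28  # assumption: MNIST images are 28x28 pixels

# Same bounds as the two range() calls in the loops above
start = int(image_size / 2 - 5)
stop = int(image_size / 2 + 3)  # exclusive upper bound, as in range()

masked_indices = list(range(start, stop))
print(masked_indices)       # [9, 10, 11, 12, 13, 14, 15, 16]
print(len(masked_indices))  # 8, so the mask covers 8 x 8 = 64 pixels
```

Note that rows and columns 9 to 16 sit one pixel up and to the left of the exact center of a 28x28 image; an exactly centered 8x8 patch would cover indices 10 to 17.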

However, I am currently not sure how I should use this in a dataloader transform. The code for the dataloader and transform is shown here:

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=4
)

def imshow(img):
    #img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(train_loader)
images, labels = next(dataiter)  # built-in next(); dataiter.next() fails on recent PyTorch versions
imshow(torchvision.utils.make_grid(images))

Is there a straightforward way to apply this masking to the full dataset as part of torchvision.transforms.Compose?


Solution

  • You can define any custom transformation as a function and use torchvision.transforms.Lambda to add it to the transformation pipeline.

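    image_size = 28  # MNIST images are 28x28; erase_middle reads this module-level value

    def erase_middle(image: torch.Tensor) -> torch.Tensor:
        # After ToTensor the image has shape (channels, height, width)
        for i in range(int(image_size/2-5),int(image_size/2+3)):
            for j in range(int(image_size/2-5),int(image_size/2+3)):
                image[:, i, j] = 0
        return image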
    
    transform = torchvision.transforms.Compose(
        [
            # First transform it to a tensor
            torchvision.transforms.ToTensor(),
            # Then erase the middle
            torchvision.transforms.Lambda(erase_middle),
        ]
    )
    

    erase_middle can be made more generic so that it works for images of varying sizes, including ones that aren't square.

    def erase_middle(image: torch.Tensor) -> torch.Tensor:
        _, height, width = image.size()
        x_start = width // 2 - 5
        x_end = width // 2 + 3
        y_start = height // 2 - 5
        y_end = height // 2 + 3
        # Using slices achieves the same as the for loops
        image[:, y_start:y_end, x_start:x_end] = 0
        return image
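
As a possible alternative to the Lambda approach, the same masking can be wrapped in a small callable class, which makes the patch size configurable. This is a sketch under assumptions, not part of the original answer: the class name EraseMiddle and its parameters patch_height and patch_width are hypothetical, it assumes the patch fits inside the image, it centers the patch exactly (rows and columns 10 to 17 on a 28x28 image, rather than 9 to 16 as in the code above), and it returns a masked copy instead of modifying the input in place.

```python
import torch


class EraseMiddle:
    """Zero out a centered rectangular patch of a (C, H, W) tensor.

    Hypothetical helper: patch_height and patch_width are illustrative
    parameters and are assumed to be no larger than the image itself.
    """

    def __init__(self, patch_height: int = 8, patch_width: int = 8):
        self.patch_height = patch_height
        self.patch_width = patch_width

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        _, height, width = image.shape
        # Top-left corner of a patch centered in the image
        y_start = (height - self.patch_height) // 2
        x_start = (width - self.patch_width) // 2
        y_end = y_start + self.patch_height
        x_end = x_start + self.patch_width
        # Work on a copy so the caller's tensor is left untouched
        masked = image.clone()
        masked[:, y_start:y_end, x_start:x_end] = 0
        return masked


# Quick check on a dummy MNIST-sized tensor filled with ones
dummy = torch.ones(1, 28, 28)
masked = EraseMiddle()(dummy)
print(masked.sum().item())  # 720.0: 784 pixels minus the 64 that were zeroed
```

An instance can be placed in the pipeline right after ToTensor, for example torchvision.transforms.Compose([torchvision.transforms.ToTensor(), EraseMiddle()]), since Compose simply calls each transform in order.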