Tags: machine-learning, deep-learning, pytorch, computer-vision, torchvision

Augmentation with torchvision transforms is not working as expected


I'm developing a CNN using PyTorch. Without augmentation, my model gives good accuracy on both the training and test sets. I wanted to learn augmentation, so I used torchvision transforms for it, but after applying the augmentation the model started doing worse and the loss does not decrease at all. While debugging, I noticed that the augmented image looks distorted. Can somebody please help me solve this?

Custom dataset

import torch
from torch.utils.data import Dataset

class traindataset(Dataset):
    def __init__(self, data, train_end_idx, augmentation=None):
        '''
        data: pandas DataFrame read from a CSV file with columns [name, labels, col1, col2, ..., col784].
              Shape of data: (10000, 786)
        '''
        self.data = data
        self.augmentation = augmentation
        self.train_end = train_end_idx
        self.target = self.data.iloc[:self.train_end, 1].values   # labels
        self.image = self.data.iloc[:self.train_end, 2:].values   # pixel columns (784 per row)

    def __len__(self):
        return len(self.target)

    def __getitem__(self, idx):
        ima = self.image[idx].reshape(1, 784)  # only the selected row, shape (1, 784)
        if self.augmentation is not None:
            ima = self.augmentation(ima)
        return torch.tensor(self.target[idx]), ima

Augmentation used

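import numpy as np
from torchvision import transforms

torchvision_transform = transforms.Compose([
    np.uint8,                              # cast the pixel array to uint8 (0..255)
    transforms.ToPILImage(),
    transforms.Resize((28, 28)),
    transforms.RandomRotation([45, 135]),  # random angle between 45 and 135 degrees
    transforms.ToTensor()
])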

Augmented image:

import matplotlib.pyplot as plt

# x is the same sample as in the "Normal image" snippet below
transformed = torchvision_transform(x)
plt.imshow(transformed.squeeze().numpy(), interpolation='nearest')
plt.show()

Normal image

x = data.iloc[:1, 2:].values   # first row of pixel values, shape (1, 784)
plt.imshow(x.reshape(28, 28), interpolation='nearest')
plt.show()

[Images: with augmentation (first), without augmentation (second)]

The first image is with augmentation and the second is without augmentation. If you want, you can play with the code here without downloading anything.


Solution

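  • The problem is the input shape rather than transforms.Resize() itself. Resize does not reshape: reshape(1, 784) produces an image that is 1 pixel tall and 784 pixels wide, and Resize((28, 28)) stretches that strip into a 28x28 square by interpolation instead of rearranging the pixels into rows. Reshaping to (28, 28) first seems to fix the issue and produce correct images (you did this step for the albumentations section). The same change is needed in traindataset.__getitem__ (reshape(28, 28) instead of reshape(1, 784)):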

    transformed = torchvision_transform(x.reshape(28,28))