Tags: pytorch, torchvision, pytorch-dataloader

Alternate training between two datasets


I am trying to alternate between an augmented and a non-augmented dataset from one epoch to the next (for example: augmented data in one epoch, non-augmented data in the next), but I couldn't figure out how to do it. My approach was to re-create the DataLoader in every epoch, but I think that's wrong: when I print the indices in __getitem__ in my Dataset class, there are a lot of duplicate indices.

Here is my code for training:

for i in range(epoch):

    train_loss = 0.0
    valid_loss = 0.0
    since = time.time()
    scheduler.step(i)
    lr = scheduler.get_lr()

    #######################################################
    #Training Data
    #######################################################

    model_test.train()
    k = 1
    tx = None
    lx = None
    random_ = random.randint(0, 1)
    print("random_ =", random_)  # debug: which pipeline this epoch uses
    if random_== 0:
            tx = torchvision.transforms.Compose([
                #  torchvision.transforms.Resize((128,128)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

            lx = torchvision.transforms.Compose([
                    #  torchvision.transforms.Resize((128,128)),
                    torchvision.transforms.Grayscale(),
                    torchvision.transforms.ToTensor(),
                    # torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
                ])
    else:
            tx = torchvision.transforms.Compose([
                #  torchvision.transforms.Resize((128,128)),
                
                torchvision.transforms.CenterCrop(96),
                torchvision.transforms.RandomRotation((-10, 10)),
                # torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])
            lx = torchvision.transforms.Compose([
                    #  torchvision.transforms.Resize((128,128)),
                    
                    torchvision.transforms.CenterCrop(96),
                    torchvision.transforms.RandomRotation((-10, 10)),
                    torchvision.transforms.Grayscale(),
                    torchvision.transforms.ToTensor(),
                    # torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
                ])
    Training_Data = Images_Dataset_folder(t_data, l_data, tx, lx)
    train_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=train_sampler,
                                           num_workers=num_workers, pin_memory=pin_memory,)

    valid_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=valid_sampler,
                                           num_workers=num_workers, pin_memory=pin_memory,)

    
    for x,y in train_loader:

        x, y = x.to(device), y.to(device)
       
        # optionally dump the input images with their augmentations,
        # to check the data flowing through the net
        input_images(x, y, i, n_iter, k)

       # grid_img = torchvision.utils.make_grid(x)
        #writer1.add_image('images', grid_img, 0)

       # grid_lab = torchvision.utils.make_grid(y)

        opt.zero_grad()

        y_pred = model_test(x)
        lossT = calc_loss(y_pred, y)     # Dice_loss Used

        train_loss += lossT.item() * x.size(0)
        lossT.backward()
      #  plot_grad_flow(model_test.named_parameters(), n_iter)
        opt.step()
        x_size = lossT.item() * x.size(0)
        k = 2

Here is my code for the dataset:

class Images_Dataset_folder(Dataset):
    def __init__(self, images_dir, labels_dir, transformI=None, transformM=None):
        self.images = sorted(os.listdir(images_dir))
        self.labels = sorted(os.listdir(labels_dir))
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.tx = transformI
        self.lx = transformM

        

    def __len__(self):

        return len(self.images)

    def __getitem__(self, i):
        # debug: log every sampled index to spot duplicates
        with open("/content/x.txt", "a") as o:
            o.write(str(i) + "\n")
        i1 = Image.open(self.images_dir + self.images[i])
        l1 = Image.open(self.labels_dir + self.labels[i])

        seed = np.random.randint(0, 2 ** 32)  # make a seed with numpy generator

        # apply this seed to the image transforms
        random.seed(seed)
        torch.manual_seed(seed)
        
        img = self.tx(i1) 

        # apply this seed to the target/label transforms
        random.seed(seed)
        torch.manual_seed(seed)
        label = self.lx(l1)

        return img, label

How can I achieve what I want? Thanks in advance.


Solution

  • Instantiating a new dataset and data loader on every epoch doesn't seem to be the way to go. Instead, you may want to instantiate two dataset + data loader pairs, each with its own augmentation pipeline.

    Here is an example to give you a basic frame:

    Start by defining the transformation pipelines inside the dataset itself:

    import torchvision.transforms as T
    from torch.utils.data import Dataset, DataLoader

    class Images_Dataset_folder(Dataset):
        def __init__(self, images_dir, labels_dir, augment=False):
            super().__init__()
            # store dirs / file lists here, as in your current dataset
            self.tx, self.lx = self._augmentations() if augment else self._no_augmentations()

        def __len__(self):
            pass  # same as in your current dataset

        def __getitem__(self, i):
            pass  # same as in your current dataset

        def _augmentations(self):
            # the augmentation pipeline from your `else` branch
            tx = T.Compose([
                T.CenterCrop(96),
                T.RandomRotation((-10, 10)),
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                T.ToTensor(),
                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

            lx = T.Compose([
                T.CenterCrop(96),
                T.RandomRotation((-10, 10)),
                T.Grayscale(),
                T.ToTensor()])

            return tx, lx

        def _no_augmentations(self):
            # the plain pipeline from your `if random_ == 0` branch
            tx = T.Compose([
                T.ToTensor(),
                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

            lx = T.Compose([
                T.Grayscale(),
                T.ToTensor()])

            return tx, lx
    

    Then you can construct your training loop as:

    # augmented images dataset
    aug_trainset = Images_Dataset_folder(t_data, l_data, augment=True)
    aug_dataloader = DataLoader(aug_trainset, batch_size=batch_size)

    # unaugmented images dataset
    unaug_trainset = Images_Dataset_folder(t_data, l_data, augment=False)
    unaug_dataloader = DataLoader(unaug_trainset, batch_size=batch_size)

    # each iteration of this loop covers two epochs: one augmented, one not
    for i in range(epochs // 2):
        # train one epoch on the augmented data loader
        train(model, aug_dataloader)

        # train one epoch on the unaugmented data loader
        train(model, unaug_dataloader)

    This being said, you will essentially loop over the dataset twice per iteration of the outer loop: once with augmented images and a second time with unaugmented ones.
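    If you would rather structure this as one loader per epoch (which also handles an odd total number of epochs), you could pick the loader by epoch parity instead; a minimal sketch, reusing the two loaders defined above:

    for i in range(epochs):
        # even epochs train on augmented data, odd epochs on the plain data
        loader = aug_dataloader if i % 2 == 0 else unaug_dataloader
        train(model, loader)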

    If you want to iterate over the data only once per epoch, then the easiest solution I can come up with is a random flag inside __getitem__ that decides whether or not the current image gets augmented; a sketch follows.
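    A minimal sketch of that idea, meant as a drop-in __getitem__ for the class above; it reuses the two pipeline helpers and the shared-seed trick from your current dataset (the 50% probability is an arbitrary choice):

    # inside Images_Dataset_folder
    # (assumes: import random, numpy as np, torch, and from PIL import Image)
    def __getitem__(self, i):
        i1 = Image.open(self.images_dir + self.images[i])
        l1 = Image.open(self.labels_dir + self.labels[i])

        # coin flip decides whether this sample goes through augmentation
        tx, lx = self._augmentations() if random.random() < 0.5 else self._no_augmentations()

        # share one seed so image and label get identical random transforms
        seed = np.random.randint(0, 2 ** 32)
        random.seed(seed)
        torch.manual_seed(seed)
        img = tx(i1)

        random.seed(seed)
        torch.manual_seed(seed)
        label = lx(l1)

        return img, label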


    Side note: you wouldn't want to use train data in your validation set!
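
    A minimal sketch of a cleaner split, assuming an 80/20 partition (the ratio and the use of Subset are my choice, not from your code): carve disjoint index sets out of the unaugmented dataset and build the validation loader only from those indices.

    import torch
    from torch.utils.data import DataLoader, Subset

    full = Images_Dataset_folder(t_data, l_data, augment=False)
    n_valid = int(0.2 * len(full))           # hypothetical 20% held out
    perm = torch.randperm(len(full)).tolist()

    valid_loader = DataLoader(Subset(full, perm[:n_valid]), batch_size=batch_size)
    # reuse perm[n_valid:] for both training datasets so the splits stay disjoint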