Search code examples
Tags: pytorch, torchvision, pytorch-dataloader

PyTorch random_split() is returning wrong sized loader


I have a custom dataset class for my data. I want to split the dataset into 70% training data, 20% validation data, and 10% test data. I have 16,488 samples, so my training set should contain about 11,542 samples. Instead, I am getting 770 for the training set, 220 for the validation set, and 110 for the test set. I've tried, but couldn't figure out the problem.

class Dataset(Dataset):
    """Image dataset whose age/gender labels are parsed from file names.

    Every file in ``directory`` whose name matches the ``parse`` pattern
    ``'{}_{}_{age}_{gender}.jpg'`` becomes one sample; other files are
    skipped. The ``gender`` field must be ``'m'`` or ``'f'`` (class ids 0
    and 1). Each sample is ``(image, {'age': int, 'gender': int})``.

    NOTE(review): this class shadows the ``Dataset`` base class it inherits
    from -- after this definition the module-level name ``Dataset`` refers
    to this subclass. The name is kept for backward compatibility.
    """

    # Gender letter parsed from a file name -> integer class id.
    _GENDER_TO_CLASS_ID = {'m': 0, 'f': 1}

    def __init__(self, directory, transform, preload=False, device: torch.device = torch.device('cpu'), **kwargs):
        """Scan ``directory`` and collect images (or their paths) plus labels.

        Args:
            directory: folder listed (non-recursively) with ``os.listdir``.
            transform: callable applied to each RGB PIL image, or ``None``.
            preload: if True, every image is opened and transformed here;
                otherwise only the file path is stored and the image is
                loaded lazily in ``__getitem__``.
            device: device that tensor images are moved to.
            **kwargs: accepted and ignored.

        Raises:
            KeyError: a matching file name has a gender other than 'm'/'f'.
            ValueError: a matching file name has a non-integer age.
        """
        self.device = device
        self.directory = directory
        self.transform = transform
        self.labels = []
        self.images = []
        self.preload = preload

        for file in os.listdir(self.directory):
            file_labels = parse('{}_{}_{age}_{gender}.jpg', file)
            if file_labels is None:
                # File name does not follow the expected pattern; skip it.
                continue

            path = os.path.join(self.directory, file)
            self.images.append(self._load_image(path) if self.preload else path)

            gender = self._GENDER_TO_CLASS_ID[file_labels['gender']]
            age = int(file_labels['age'])
            self.labels.append({
                'age': age,
                'gender': gender,
            })

    def _load_image(self, path):
        """Open ``path`` as RGB, apply ``self.transform``, move tensors to ``self.device``."""
        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        # Without a tensor-producing transform the image stays a PIL image,
        # which has no ``.to()``; only tensors are moved.
        if torch.is_tensor(image):
            image = image.to(self.device)
        return image

    def __len__(self):
        """Return the number of samples (files that matched the name pattern)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, {'age': int, 'gender': int})`` for sample ``idx``.

        ``image`` is a tensor on ``self.device`` when a tensor-producing
        transform is set, otherwise whatever was stored/loaded (e.g. a PIL
        image).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.images[idx]
        if not self.preload:
            # Lazy mode: ``self.images`` holds file paths.
            image = self._load_image(image)
        elif torch.is_tensor(image):
            # Preloaded tensor; re-applied in case ``self.device`` changed.
            image = image.to(self.device)

        labels = {
            'age': self.labels[idx]['age'],
            'gender': self.labels[idx]['gender'],
        }
        return image, labels

    def get_loaders(self, transform, train_size=0.7, validate_size=0.2, test_size=0.1, batch_size=15, **kwargs):
        """Randomly split this dataset and wrap each part in a ``DataLoader``.

        Split lengths are ``int(len(self) * fraction)``; any remainder is
        discarded. The splits are stored on ``self.trainDataset``,
        ``self.validateDataset`` and ``self.testDataset``.

        Note: ``len()`` of a returned ``DataLoader`` is the number of
        *batches*, ``ceil(samples / batch_size)``, not the number of samples.
        Use ``len(loader.dataset)`` for the sample count.

        Args:
            transform: unused; kept for backward compatibility.
            train_size, validate_size, test_size: fractions of the dataset;
                their sum must not exceed 1.
            batch_size: batch size of every returned loader.
            **kwargs: accepted and ignored.

        Returns:
            ``(train_loader, validate_loader, test_loader)``, unshuffled.

        Raises:
            SystemExit: the fractions sum to more than 1.
        """
        total = train_size + validate_size + test_size
        # A tiny tolerance accepts float noise such as 0.7 + 0.2 + 0.1 ==
        # 0.9999999999999999. Rounding to one decimal (the old check) also let
        # real overshoots like 1.04 through, giving a negative remainder length.
        if total > 1.0 + 1e-9:
            sys.exit("Sum of the percentages should not exceed 1. it's " + str(total) + " now!")

        n = len(self)
        train_len = int(n * train_size)
        validate_len = int(n * validate_size)
        test_len = int(n * test_size)
        others_len = n - train_len - validate_len - test_len

        self.trainDataset, self.validateDataset, self.testDataset, _ = torch.utils.data.random_split(
            self, [train_len, validate_len, test_len, others_len]
        )

        train_loader = DataLoader(self.trainDataset, batch_size=batch_size)
        validate_loader = DataLoader(self.validateDataset, batch_size=batch_size)
        test_loader = DataLoader(self.testDataset, batch_size=batch_size)

        return train_loader, validate_loader, test_loader

Solution

  • It seems that you are giving

    batch_size=15
    

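    `random_split()` is actually working correctly; what you are measuring is the `DataLoader`. `len(loader)` returns the number of *batches* the loader will yield, which is ⌈samples / batch_size⌉, not the number of samples. With 16,488 samples, `int(16488 * 0.7) = 11541` training samples, so

    $\lceil 11541 / 15 \rceil = \lceil 769.4 \rceil = 770$ batches
    

    Likewise, $\lceil 3297 / 15 \rceil = 220$ validation batches and $\lceil 1648 / 15 \rceil = 110$ test batches, which are exactly the numbers you observed.

    To get the number of samples, use `len(train_loader.dataset)` (and likewise for the other loaders). Assigning batch_size = 1 would also make `len(loader)` equal the sample count, but it changes training behaviour and is not necessary just to count samples. Note also that `int()` truncates rather than rounds:

    $16488 \times 0.7 = 11541.6$, which `int()` truncates to 11,541 (not 11,542).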