Search code examples
Tags: pytorch, pytorch-dataloader, dataloader

Modify PyTorch DataLoader to not mix files from different directories in batch


I want to load image sequences of a fixed length into batches of the same size (for example sequence length = batch size = 7).

There are multiple directories each with images from a sequence with varying number of images. The sequences from different directories are not related to each other.

With my current code, I can process several subdirectories, but if there are not enough images in one directory to fill a batch, the remaining images are taken from the next directory. I would like to avoid this.

Instead, if the current directory does not have enough images left to fill a batch, those leftover images should be discarded, and the next batch should be filled only with images from the next directory. This way, I want to avoid mixing unrelated image sequences in the same batch. If a directory does not have enough images to create even a single batch, it should be skipped completely.

So for example with a sequence length/batch size of 7:

  • directory A has 15 images → 2 batches each with 7 images are created; the rest are ignored
  • directory B has 10 images → 1 batch with 7 images is created; the rest are ignored
  • directory C has 3 images → directory is skipped entirely

I’m still learning, but I think this could be done with a custom batch sampler. Unfortunately, I have run into some problems with it. Maybe someone can help me find a solution.

This is my current code:

class MainDataset(Dataset):
    """Flat dataset over every valid image file found under a directory.

    Each index maps to exactly one image path. All paths are sorted together,
    so images from different subdirectories sit next to each other in index
    order.
    """

    def __init__(self, img_dir, use_folder_name=False):
        """Scan ``img_dir`` for images and remember how item names are derived.

        Args:
            img_dir: Directory to walk recursively, or a single file path.
            use_folder_name: If True, items are named after their parent
                directory instead of their file stem.
        """
        paths = self._load_main_dataset(img_dir)
        self.gt_images = paths
        self.dataset_len = len(paths)
        self.use_folder_name = use_folder_name

    def __len__(self):
        """Return the number of image paths collected at construction time."""
        return self.dataset_len

    def __getitem__(self, idx):
        """Return ``(image_tensor, name)`` for the image path at ``idx``.

        The loaded array's last axis is moved to the front with
        ``permute(2, 0, 1)`` (presumably HWC -> CHW; confirm with the image
        source). Returns None when ``_load_img`` yields None.
        """
        path = self.gt_images[idx]
        name = self._get_name(path)
        pixels = self._load_img(path)
        if pixels is None:
            # NOTE(review): the visible _load_img never returns None, and a None
            # item would likely fail in the DataLoader's default collate.
            return None
        return torch.from_numpy(pixels).permute(2, 0, 1), name

    def _get_name(self, img_dir):
        """Return the parent folder name, or the file name up to its first dot."""
        parts = img_dir.split(os.sep)
        if self.use_folder_name:
            return parts[-2]
        return parts[-1].split('.')[0]

    def _load_main_dataset(self, img_dir):
        """Collect and sort all valid image paths below ``img_dir``.

        A path that is not a directory is returned unchanged as a one-element
        list, without any extension check.
        """
        if not os.path.isdir(img_dir):
            return [img_dir]
        found = []
        for root, _subdirs, names in os.walk(img_dir):
            found.extend(os.path.join(root, n) for n in names if is_valid_file(n))
        return sorted(found)

    def _load_img(self, img_path):
        """Read an image and scale it to float32 using ``getBitDepth``.

        NOTE(review): the divisor ``2 ** (bit_depth / 3) - 1`` suggests that
        getBitDepth reports the total depth over three channels; confirm this.
        """
        raw = io.imread(img_path)
        bit_depth = getBitDepth(raw)
        max_value = (2 ** (bit_depth / 3)) - 1
        return np.array(raw).astype(np.float32) / max_value


def is_valid_file(file_name: str):
    """Return True if ``file_name`` ends with a supported image extension.

    The check is case-insensitive and only inspects the end of the name.
    """
    supported = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif')
    return file_name.lower().endswith(supported)



# Each index of this dataset is a single image, so with the DataLoader's
# default batch size every batch holds one image.
sequence_data_store = MainDataset(
    img_dir=sdr_img_dir,
    use_folder_name=True,
)
sequence_loader = DataLoader(
    sequence_data_store,
    num_workers=0,
    pin_memory=False,
)

Solution

  • While using a batch sampler might be a good idea to have a generic custom dataset that you can sample differently, I would prefer a straightforward approach.

    I would construct a data structure in the init function that already contains all the image sequences you'll manipulate. Currently, your Dataset class reports its length as the number of individual image files and returns one image per index. It therefore has no notion of which directory (sequence) an image belongs to. The DataLoader simply groups consecutive indices into batches, which is why a batch can span two directories.

    Currently, your dataset only returns one image at a time while you are expecting sequences.

    Some information about the actual structure of the dataset is also missing from your question. Nevertheless, here is a proposed Dataset class:

    class MainDataset(Dataset):
        """Dataset of fixed-length image sequences that never span two directories.

        Every directory visited under ``img_dir`` (including ``img_dir`` itself)
        is handled independently. Its valid image files are sorted by path and cut
        into consecutive, non-overlapping chunks of exactly ``seq_len`` images.
        Images left over after the last full chunk are dropped, and a directory
        with fewer than ``seq_len`` images contributes nothing. Each dataset item
        is one such chunk, so a DataLoader with ``batch_size=1`` yields one
        sequence per batch.
        """

        def __init__(self, img_dir, use_folder_name=False, seq_len=7):
            """Scan ``img_dir`` and build the list of sequences.

            Args:
                img_dir: Root directory to walk recursively. A non-directory path
                    is treated as a single image and only forms a sequence when
                    ``seq_len == 1``.
                use_folder_name: If True, each sequence is labelled with the name
                    of the directory containing it. Otherwise it is labelled with
                    the file name (up to its first dot) of its first image.
                seq_len: Number of images per sequence; must be >= 1.

            Raises:
                ValueError: If ``seq_len`` is smaller than 1.
            """
            if seq_len < 1:
                raise ValueError(f"seq_len must be >= 1, got {seq_len}")
            self.seq_len = seq_len
            # Must be set before scanning: _load_main_dataset calls _get_name,
            # which reads this attribute.
            self.use_folder_name = use_folder_name
            self.gt_images = self._load_main_dataset(img_dir)

        def __len__(self):
            """Return the number of full sequences (not the number of images)."""
            return len(self.gt_images)

        def __getitem__(self, idx):
            """Load sequence ``idx`` and return ``(frames, label)``.

            ``frames`` stacks the sequence's images along a new leading dimension
            of size ``seq_len``. Each image's last axis is moved to the front with
            ``permute(2, 0, 1)`` (presumably HWC -> CHW; confirm with the data).
            ``label`` is the string produced by ``_get_name``.
            """
            label, sequence = self.gt_images[idx]
            frames = [
                torch.from_numpy(self._load_img(image_path)).permute(2, 0, 1)
                for image_path in sequence
            ]
            # torch.stack requires every image of a sequence to share one shape.
            return torch.stack(frames, dim=0), label

        def _get_name(self, img_dir):
            """Return the parent folder name, or the file name up to its first dot."""
            if self.use_folder_name:
                return img_dir.split(os.sep)[-2]
            else:
                return img_dir.split(os.sep)[-1].split('.')[0]

        @staticmethod
        def _split_into_sequences(paths, seq_len):
            """Cut ``paths`` into consecutive chunks of exactly ``seq_len``; drop the rest."""
            n_full = len(paths) // seq_len
            return [paths[i * seq_len:(i + 1) * seq_len] for i in range(n_full)]

        def _load_main_dataset(self, img_dir):
            """Return a list of ``(label, sequence)`` tuples, one per full sequence.

            Each ``sequence`` is a list of ``seq_len`` image paths from a single
            directory, in sorted (lexicographic) order. Frame numbers should be
            zero-padded for this order to match temporal order.
            """
            if os.path.isdir(img_dir):
                groups = []
                for root, dirs, files in os.walk(img_dir):
                    # Sorting in place makes os.walk visit subdirectories in a
                    # deterministic order, so dataset indices are reproducible.
                    dirs.sort()
                    # Sort the whole directory BEFORE chunking so that each
                    # sequence contains consecutive frames.
                    groups.append(sorted(
                        os.path.join(root, file) for file in files if is_valid_file(file)
                    ))
            else:
                # A plain path is treated as a one-image group. Unlike the
                # directory branch, it is not filtered by is_valid_file.
                groups = [[img_dir]]

            gt_images = []
            for group in groups:
                for sequence in self._split_into_sequences(group, self.seq_len):
                    gt_images.append((self._get_name(sequence[0]), sequence))
            return gt_images

        def _load_img(self, img_path):
            """Read an image and scale it to float32 using ``getBitDepth``.

            NOTE(review): the divisor ``2 ** (bd / 3) - 1`` suggests that
            getBitDepth reports the total depth over three channels; confirm this.
            """
            gt_image = io.imread(img_path)
            gt_image_bd = getBitDepth(gt_image)
            gt_image = np.array(gt_image).astype(np.float32) / ((2 ** (gt_image_bd / 3)) - 1)

            return gt_image
    
    
    def is_valid_file(file_name: str):
        """Tell whether ``file_name`` carries one of the supported image extensions.

        The comparison is case-insensitive and only looks at the end of the name.
        """
        lowered = file_name.lower()
        return any(
            lowered.endswith(ext)
            for ext in ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif')
        )
    
    
    
    # Each dataset item is already one full sequence of seq_len images, so a
    # batch_size of 1 (the DataLoader default, spelled out here) yields exactly
    # one sequence per batch.
    sequence_data_store = MainDataset(img_dir=sdr_img_dir, use_folder_name=True, seq_len=7)
    sequence_loader = DataLoader(
        sequence_data_store,
        batch_size=1,
        num_workers=0,
        pin_memory=False,
    )