Search code examples
Tags: pytorch, dataset, torch, torchvision

pytorch: Merge three datasets with predefined and custom datasets


I am training an AI model to recognize handwritten Hangul characters along with English characters and numbers, so I need three datasets: a custom Korean character dataset plus two others.

I have the three datasets and I am now merging them, but when I print a sample path from `train_set` it always points into the MJSynth folder, which is wrong.

The file 긴장_1227682.jpg is in my custom Korean dataset, not in MJSynth.

Code

# Build the custom (Korean) recognition dataset from local image folders
# plus JSON label files; any extra part folders are merged into the same
# dataset via docTR's `merge_dataset`.
custom_train_set = RecognitionDataset(
            parts[0].joinpath("images"),
            parts[0].joinpath("labels.json"),
            img_transforms=Compose(
                [
                    T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                    # Augmentations
                    T.RandomApply(T.ColorInversion(), 0.1),
                    ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
                ]
            ),
        )
        if len(parts) > 1:
            for subfolder in parts[1:]:
                custom_train_set.merge_dataset(
                    RecognitionDataset(subfolder.joinpath("images"), subfolder.joinpath("labels.json"))
                )

        # MJSynth becomes the base dataset, so its `img_folder` is the root
        # that is later used to resolve every relative image filename (see the
        # traceback below: `os.path.join(self.root, img_name)`).
        train_set = MJSynth(
            train=True,
            img_folder='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px',
            label_path='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt',
            img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
        )

        _train_set = SynthText(
            train=True,
            recognition_task=True,
            download=True,  # NOTE: download can take really long depending on your bandwidth
            img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
        )
        # BUG (cause of the FileNotFoundError below): extending `train_set.data`
        # keeps MJSynth's root, so the custom dataset's filenames (e.g.
        # 긴장_1227682.jpg) get resolved against the MJSynth folder at load time.
        train_set.data.extend([(np_img, target) for np_img, target in _train_set.data])
        train_set.data.extend([(np_img, target) for np_img, target in custom_train_set.data])

Traceback

Traceback (most recent call last):
  File "/media/cvpr/CM_22/doctr/references/recognition/train_pytorch.py", line 485, in <module>
    main(args)
  File "/media/cvpr/CM_22/doctr/references/recognition/train_pytorch.py", line 396, in main
    fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)
  File "/media/cvpr/CM_22/doctr/references/recognition/train_pytorch.py", line 118, in fit_one_epoch
    for images, targets in progress_bar(train_loader, parent=mb):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/fastprogress/fastprogress.py", line 50, in __iter__
    raise e
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/fastprogress/fastprogress.py", line 41, in __iter__
    for i,o in enumerate(self.gen):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1333, in _next_data
    return self._process_data(data)
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
    data.reraise()
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/media/cvpr/CM_22/doctr/doctr/datasets/datasets/base.py", line 48, in __getitem__
    img, target = self._read_sample(index)
  File "/media/cvpr/CM_22/doctr/doctr/datasets/datasets/pytorch.py", line 37, in _read_sample
    else read_img_as_tensor(os.path.join(self.root, img_name), dtype=torch.float32)
  File "/media/cvpr/CM_22/doctr/doctr/io/image/pytorch.py", line 52, in read_img_as_tensor
    pil_img = Image.open(img_path, mode="r").convert("RGB")
  File "/home/cvpr/anaconda3/envs/pytesseract/lib/python3.9/site-packages/PIL/Image.py", line 2912, in open
    fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px/긴장_1227682.jpg'

Solution

  • You have to change the arrangement of your three datasets. Because you are using the docTR library, merging datasets works differently than a normal ConcatDataset in PyTorch.

    Print the size of each individual dataset so that you can verify its actual length as well as the overall merged dataset size.

    mjsynth_train = MJSynth(
                train=True,
                img_folder='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px',
                label_path='/media/cvpr/CM_22/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt',
                img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
            )
    
            print("MJSynth dataset size is", len(mjsynth_train))
    
            synth_train = SynthText(
                recognition_task=True,
                download=True,  # NOTE: download can take really long depending on your bandwidth
                img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
            )
    
            print("SynthText dataset size is", len(synth_train))
    
            train_set.data.extend([(np_img, target) for np_img, target in mjsynth_train.data])
            train_set.data.extend([(np_img, target) for np_img, target in synth_train.data])
    
            print("Overall dataset size is", len(train_set))