Search code examples
python-3.xpytorchtorchvisionpytorch-dataloader

torchvision CIFAR10 augmentation gives TypeError: Unexpected type <class 'numpy.ndarray'>


For CIFAR-10 data augmentations using torchvision transforms. torchvision version: '0.15.2+cu117' and torch version: 2.0.1+cu117

strength = 0.2
color_jitter = transforms.ColorJitter(
    brightness = 0.8 * strength, contrast = 0.8 * strength,
    saturation = 0.8 * strength, hue = 0.2 * strength
)

rand_color_jitter = transforms.RandomApply([color_jitter], p = 0.8)

rand_gray = transforms.RandomGrayscale(p = 0.2)

color_distortion = transforms.Compose(
    [
        rand_color_jitter,
        rand_gray
    ]
)

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

train_data_augmentation = transforms.Compose(
    [
        transforms.RandomResizedCrop(size = 32, scale = (0.14, 1.0)),
        transforms.GaussianBlur(kernel_size = (3, 3), sigma = (0.1, 2.0)),
        color_distortion,
        # v2.ToImage(),
        # v2.ToDtype(torch.float32, scale = True),
        # v2.ToDtype(torch.float32),
        # v2.Normalize(),
        transforms.ToTensor(),
        transforms.Normalize(mean = mean, std = std)
    ]
)

I use these augmentations inside a datasets class:

class Cifar10Dataset(torchvision.datasets.CIFAR10):
    def __init__(
        self, root = "~/data/cifar10",
        train = True, download = True,
        transform = None
    ):
        super().__init__(
            root = root, train = train,
            download = download, transform = transform
        )

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            image = self.transform(image)
            # image = transformed["image"]
            
            # Randomly select 0, 1, 2 or 3 for image rotation-
            ang = np.random.randint(low = 0, high = 4, size = None)
            image = TF.rotate(img = image, angle = ang * 90)
            
        return image, ang


train_dataset = Cifar10Dataset(
    root = some_path, train = True,
    download = True, transform = train_data_augmentation
)

train_loader = torch.utils.data.DataLoader(
    dataset = train_dataset, batch_size = 128,
    shuffle = True
)

x, y = next(iter(training_loader))

Which throws the error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[124], line 1
----> 1 x, y = next(iter(train_loader))

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/utils/data/dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
    630 if self._sampler_iter is None:
    631     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632     self._reset()  # type: ignore[call-arg]
--> 633 data = self._next_data()
    634 self._num_yielded += 1
    635 if self._dataset_kind == _DatasetKind.Iterable and \
    636         self._IterableDataset_len_called is not None and \
    637         self._num_yielded > self._IterableDataset_len_called:

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/utils/data/dataloader.py:677, in _SingleProcessDataLoaderIter._next_data(self)
    675 def _next_data(self):
    676     index = self._next_index()  # may raise StopIteration
--> 677     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    678     if self._pin_memory:
    679         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51, in <listcomp>(.0)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

Cell In[121], line 16, in Cifar10Dataset.__getitem__(self, index)
     13 image, label = self.data[index], self.targets[index]
     15 if self.transform is not None:
---> 16     image = self.transform(image)
     17     # image = transformed["image"]
     18     
     19     # Randomly select 0, 1, 2 or 3 for image rotation-
     20     ang = np.random.randint(low = 0, high = 4, size = None)

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/transforms.py:95, in Compose.__call__(self, img)
     93 def __call__(self, img):
     94     for t in self.transforms:
---> 95         img = t(img)
     96     return img

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/transforms.py:979, in RandomResizedCrop.forward(self, img)
    971 def forward(self, img):
    972     """
    973     Args:
    974         img (PIL Image or Tensor): Image to be cropped and resized.
   (...)
    977         PIL Image or Tensor: Randomly cropped and resized image.
    978     """
--> 979     i, j, h, w = self.get_params(img, self.scale, self.ratio)
    980     return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/transforms.py:940, in RandomResizedCrop.get_params(img, scale, ratio)
    927 @staticmethod
    928 def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
    929     """Get parameters for ``crop`` for a random sized crop.
    930 
    931     Args:
   (...)
    938         sized crop.
    939     """
--> 940     _, height, width = F.get_dimensions(img)
    941     area = height * width
    943     log_ratio = torch.log(torch.tensor(ratio))

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/functional.py:78, in get_dimensions(img)
     75 if isinstance(img, torch.Tensor):
     76     return F_t.get_dimensions(img)
---> 78 return F_pil.get_dimensions(img)

File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/_functional_pil.py:31, in get_dimensions(img)
     29     width, height = img.size
     30     return [channels, height, width]
---> 31 raise TypeError(f"Unexpected type {type(img)}")

TypeError: Unexpected type <class 'numpy.ndarray'>

Shouldn't RandomResizedCrop() be able to convert np array to torch tensor? What am I missing?


Solution

  • No, torch.tensors and numpy.arrays are not fully interchangeable, even though they can be used as such in many cases. (As far as I know, this has something to do with the fact that torch needs to handle ownership across many devices.)

    The documentation for RandomResizedCrop does state that the only accepted input types are PIL.Image and torch.Tensor. So you need to convert your images to tensors first.

    What you usually want to do is run as many of the augmentations on ByteTensors as possible, and then do scaling and normalization in the end. Using v2 transforms, you're probably looking for something like this:

    v2.Compose([
      v2.ToImageTensor(), # [H,W,C] NDArray[uint8] -> [C,H,W] ByteTensor 
      v2.RandomResizedCrop(),
      v2.GaussianBlur(),
      v2.ColorJitter(),
      v2.ConvertDtype(), # ByteTensor (0, 255) -> FloatTensor (0, 1)
      v2.Normalize(),
    ])
    

    or for torchvision versions >=0.16

    v2.Compose([
      v2.ToImage(), # [H,W,C] NDArray[uint8] -> [C,H,W] ByteTensor 
      v2.RandomResizedCrop(),
      v2.GaussianBlur(),
      v2.ColorJitter(),
      v2.ToDtype(), # ByteTensor (0, 255) -> FloatTensor (0, 1)
      v2.Normalize(),
    ])