Search code examples
pythonneural-networkpytorch

Transform DVS128Gesture dataset


I am making a neural network in Python using the DVS128Gesture dataset. I want to transform the default 128x128 trinary frames to 32x32 binary frames, but when I try to use torchvision.transform in the dataset, I am getting this error:

img should be PIL Image. Got <class 'numpy.lib.npyio.NpzFile'>

My code:

import torch
import torchvision
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture

train_data = DVS128Gesture(root_dir, train=True, data_type='event',
                            transform=torchvision.transforms.Compose([
                                torchvision.transforms.Resize(32),
                                torchvision.transforms.Normalize((0.0,), (0.8,)),
                                torchvision.transforms.ToTensor()
                            ]))
test_data = DVS128Gesture(root_dir, train=False, data_type='event',
                           transform=torchvision.transforms.Compose([
                               torchvision.transforms.Resize(32),
                               torchvision.transforms.Normalize((0.0,), (0.8,)),
                               torchvision.transforms.ToTensor()
                           ]))

train_loader = torch.utils.data.DataLoader(train_data, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=bs, shuffle=True)

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
example_data.shape

I have done the same with the MNIST dataset and everything worked as expected. I think the problem is that I use torchvision.transform in DVS128Gesture, but I am not sure what else I can use.

The same with MNIST:

train_data = torchvision.datasets.MNIST(root_dir, train=True, download=True,
                            transform=torchvision.transforms.Compose([
                                torchvision.transforms.Resize(28),
                                torchvision.transforms.ToTensor(),
                                torchvision.transforms.Normalize((0.0,), (0.8,))
                            ]))
test_data = torchvision.datasets.MNIST(root_dir, train=False, download=True,
                           transform=torchvision.transforms.Compose([
                               torchvision.transforms.Resize(28),
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize((0.0,), (0.8,))
                           ]))

What am I doing wrong?

Stack trace of error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [10], line 2
      1 examples = enumerate(test_loader)
----> 2 batch_idx, (example_data, example_targets) = next(examples)
      3 example_data.shape

File p:\Programs\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py:681, in _BaseDataLoaderIter.__next__(self)
    678 if self._sampler_iter is None:
    679     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    680     self._reset()  # type: ignore[call-arg]
--> 681 data = self._next_data()
    682 self._num_yielded += 1
    683 if self._dataset_kind == _DatasetKind.Iterable and \
    684         self._IterableDataset_len_called is not None and \
    685         self._num_yielded > self._IterableDataset_len_called:

File p:\Programs\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py:721, in _SingleProcessDataLoaderIter._next_data(self)
    719 def _next_data(self):
    720     index = self._next_index()  # may raise StopIteration
--> 721     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    722     if self._pin_memory:
    723         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File p:\Programs\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py:49, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     47 def fetch(self, possibly_batched_index):
     48     if self.auto_collation:
---> 49         data = [self.dataset[idx] for idx in possibly_batched_index]
     50     else:
     51         data = self.dataset[possibly_batched_index]

File p:\Programs\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py:49, in <listcomp>(.0)
     47 def fetch(self, possibly_batched_index):
     48     if self.auto_collation:
---> 49         data = [self.dataset[idx] for idx in possibly_batched_index]
     50     else:
     51         data = self.dataset[possibly_batched_index]

File p:\Programs\Anaconda3\lib\site-packages\torchvision\datasets\folder.py:232, in DatasetFolder.__getitem__(self, index)
    230 sample = self.loader(path)
    231 if self.transform is not None:
--> 232     sample = self.transform(sample)
    233 if self.target_transform is not None:
    234     target = self.target_transform(target)

File p:\Programs\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py:94, in Compose.__call__(self, img)
     92 def __call__(self, img):
     93     for t in self.transforms:
---> 94         img = t(img)
     95     return img

File p:\Programs\Anaconda3\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File p:\Programs\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py:349, in Resize.forward(self, img)
    341 def forward(self, img):
    342     """
    343     Args:
    344         img (PIL Image or Tensor): Image to be scaled.
   (...)
    347         PIL Image or Tensor: Rescaled image.
    348     """
--> 349     return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

File p:\Programs\Anaconda3\lib\site-packages\torchvision\transforms\functional.py:430, in resize(img, size, interpolation, max_size, antialias)
    428         warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    429     pil_interpolation = pil_modes_mapping[interpolation]
--> 430     return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
    432 return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)

File p:\Programs\Anaconda3\lib\site-packages\torchvision\transforms\functional_pil.py:249, in resize(img, size, interpolation, max_size)
    240 @torch.jit.unused
    241 def resize(
    242     img: Image.Image,
   (...)
    245     max_size: Optional[int] = None,
    246 ) -> Image.Image:
    248     if not _is_pil_image(img):
--> 249         raise TypeError(f"img should be PIL Image. Got {type(img)}")
    250     if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
    251         raise TypeError(f"Got inappropriate size arg: {size}")

TypeError: img should be PIL Image. Got <class 'numpy.lib.npyio.NpzFile'>

Solution

  • The error probably comes from the Resize transform (can you provide more details on the stack trace of the error ?).

    Resize is a image specific tranform, expecting a PIL image (or a torch Tensor, see the transform documentation), while your dataset DVS128Gesture outputs another object type.