I am building a neural network in Python using the DVS128Gesture dataset. I want to transform the default 128x128 trinary frames into 32x32 binary frames, but when I try to use torchvision.transforms in the dataset, I get this error:
img should be PIL Image. Got <class 'numpy.lib.npyio.NpzFile'>
My code:
import torch
import torchvision
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
train_data = DVS128Gesture(root_dir, train=True, data_type='event',
                           transform=torchvision.transforms.Compose([
                               torchvision.transforms.Resize(32),
                               torchvision.transforms.Normalize((0.0,), (0.8,)),
                               torchvision.transforms.ToTensor()
                           ]))
test_data = DVS128Gesture(root_dir, train=False, data_type='event',
                          transform=torchvision.transforms.Compose([
                              torchvision.transforms.Resize(32),
                              torchvision.transforms.Normalize((0.0,), (0.8,)),
                              torchvision.transforms.ToTensor()
                          ]))
train_loader = torch.utils.data.DataLoader(train_data, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=bs, shuffle=True)
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
example_data.shape
I have done the same with the MNIST dataset and everything worked as expected. I think the problem is that I am using torchvision.transforms with DVS128Gesture, but I am not sure what else I could use.
The same with MNIST:
train_data = torchvision.datasets.MNIST(root_dir, train=True, download=True,
                                        transform=torchvision.transforms.Compose([
                                            torchvision.transforms.Resize(28),
                                            torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Normalize((0.0,), (0.8,))
                                        ]))
test_data = torchvision.datasets.MNIST(root_dir, train=False, download=True,
                                       transform=torchvision.transforms.Compose([
                                           torchvision.transforms.Resize(28),
                                           torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize((0.0,), (0.8,))
                                       ]))
What am I doing wrong?
Stack trace of the error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In [10], line 2
1 examples = enumerate(test_loader)
----> 2 batch_idx, (example_data, example_targets) = next(examples)
3 example_data.shape
File p:\Programs\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py:681, in _BaseDataLoaderIter.__next__(self)
678 if self._sampler_iter is None:
679 # TODO(https://github.com/pytorch/pytorch/issues/76750)
680 self._reset() # type: ignore[call-arg]
--> 681 data = self._next_data()
682 self._num_yielded += 1
683 if self._dataset_kind == _DatasetKind.Iterable and \
684 self._IterableDataset_len_called is not None and \
685 self._num_yielded > self._IterableDataset_len_called:
File p:\Programs\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py:721, in _SingleProcessDataLoaderIter._next_data(self)
719 def _next_data(self):
720 index = self._next_index() # may raise StopIteration
--> 721 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
722 if self._pin_memory:
723 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File p:\Programs\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py:49, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
47 def fetch(self, possibly_batched_index):
48 if self.auto_collation:
---> 49 data = [self.dataset[idx] for idx in possibly_batched_index]
50 else:
51 data = self.dataset[possibly_batched_index]
File p:\Programs\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py:49, in <listcomp>(.0)
47 def fetch(self, possibly_batched_index):
48 if self.auto_collation:
---> 49 data = [self.dataset[idx] for idx in possibly_batched_index]
50 else:
51 data = self.dataset[possibly_batched_index]
File p:\Programs\Anaconda3\lib\site-packages\torchvision\datasets\folder.py:232, in DatasetFolder.__getitem__(self, index)
230 sample = self.loader(path)
231 if self.transform is not None:
--> 232 sample = self.transform(sample)
233 if self.target_transform is not None:
234 target = self.target_transform(target)
File p:\Programs\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py:94, in Compose.__call__(self, img)
92 def __call__(self, img):
93 for t in self.transforms:
---> 94 img = t(img)
95 return img
File p:\Programs\Anaconda3\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File p:\Programs\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py:349, in Resize.forward(self, img)
341 def forward(self, img):
342 """
343 Args:
344 img (PIL Image or Tensor): Image to be scaled.
(...)
347 PIL Image or Tensor: Rescaled image.
348 """
--> 349 return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
File p:\Programs\Anaconda3\lib\site-packages\torchvision\transforms\functional.py:430, in resize(img, size, interpolation, max_size, antialias)
428 warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
429 pil_interpolation = pil_modes_mapping[interpolation]
--> 430 return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
432 return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
File p:\Programs\Anaconda3\lib\site-packages\torchvision\transforms\functional_pil.py:249, in resize(img, size, interpolation, max_size)
240 @torch.jit.unused
241 def resize(
242 img: Image.Image,
(...)
245 max_size: Optional[int] = None,
246 ) -> Image.Image:
248 if not _is_pil_image(img):
--> 249 raise TypeError(f"img should be PIL Image. Got {type(img)}")
250 if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
251 raise TypeError(f"Got inappropriate size arg: {size}")
TypeError: img should be PIL Image. Got <class 'numpy.lib.npyio.NpzFile'>
As the stack trace shows, the error comes from the Resize transform. Resize is an image-specific transform that expects a PIL Image (or a torch Tensor, see the transforms documentation), while your DVS128Gesture dataset, created with data_type='event', returns the loaded .npz event record (numpy.lib.npyio.NpzFile) rather than an image, so the torchvision image transforms cannot be applied to it directly.
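One way around it, if the end goal is fixed-size binary frames rather than raw events, is to let spikingjelly integrate the events into frames (data_type='frame') and then apply a small tensor-based transform instead of the PIL-oriented torchvision ones. The following is a minimal sketch, not the definitive fix: the to_binary_32x32 helper is made up for illustration, and frames_number=16, split_by='number', bilinear resizing, and thresholding at zero are all assumptions you may want to change.

import torch
import torch.nn.functional as F
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture

def to_binary_32x32(frames):
    # spikingjelly delivers each sample as a numpy array of shape [T, 2, 128, 128]
    x = torch.from_numpy(frames).float()
    # resize every frame; dim 0 acts as the batch dimension for interpolate
    x = F.interpolate(x, size=(32, 32), mode='bilinear', align_corners=False)
    # binarize: any accumulated activity in a pixel becomes 1, everything else 0
    return (x > 0).float()

train_data = DVS128Gesture(root_dir, train=True, data_type='frame',
                           frames_number=16, split_by='number',
                           transform=to_binary_32x32)
test_data = DVS128Gesture(root_dir, train=False, data_type='frame',
                          frames_number=16, split_by='number',
                          transform=to_binary_32x32)

If you prefer to stay with data_type='event', the transform you pass in has to operate on the event record itself (arrays of t, x, y, p), so none of the image transforms (Resize, ToTensor, Normalize) apply as written.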