I am writing a custom PyTorch dataset. In __init__, the dataset object loads a file that contains certain data, but in my program I only want to access part of that data (to get a train/validation split, if it helps). Originally I thought this behavior was controlled by overriding __len__, but it turned out that modifying __len__ does not help. A simple example is as follows:
from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # suppose there are 10 items in the data file

    def __len__(self):
        return len(self.data) - 5  # but I only want to access the first 5 items

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)
The output is 0 through 9, while the desired behavior would be 0 through 4.
How does this dataset object know that the enumeration has hit the end when used in a for loop like this? Any other method to achieve a similar effect is also welcome.
You can use torch.utils.data.Subset to get a subset of your data:
top_five = torch.utils.data.Subset(ds, indices=range(5))  # get the first five items
for i, x in enumerate(top_five):
    print(i)
0
1
2
3
4
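A for loop (with or without enumerate) keeps requesting items until the iterator raises a StopIteration exception. Dataset does not define __iter__, so Python falls back to calling __getitem__ with 0, 1, 2, ... until it raises IndexError, and __len__ is never consulted. In your example that first happens at index 10, which is why you see 0 through 9.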
len(ds)  # returns the modified length
5
# The `for` loop calls `next` on the iterator each time around.
# When no more data is available, a StopIteration exception is raised instead.
iter_ds = iter(ds)
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds)) #11th time StopIteration exception raised as no item left to iterate in iterable
Output:
tensor([-1.5952, -0.0826])
tensor([-2.2254, 0.2461])
tensor([-0.8268, 0.1956])
tensor([ 0.3157, -0.3403])
tensor([0.8971, 1.1255])
tensor([0.3922, 1.3184])
tensor([-0.4311, -0.8898])
tensor([ 0.1128, -0.5708])
tensor([-0.5403, -0.9036])
tensor([0.6550, 1.6777])
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-99-7a9910e027c3> in <module>
10 print(next(iter_ds))
11
---> 12 print(next(iter_ds)) #11th time StopIteration exception raised as no item left to iterate
StopIteration:
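If you would rather have the dataset itself stop at __len__, one option is to raise IndexError from __getitem__ once the index passes the reported length. The sketch below shows the idea in plain Python, with no torch dependency. ListDS and BoundedListDS are hypothetical stand-ins for NewDS that hold a list instead of a tensor, and the bounds check deliberately rejects negative indices to keep the example short.

```python
class ListDS:
    """Hypothetical stand-in for NewDS: ten items stored, length reported as five."""

    def __init__(self):
        self.data = list(range(10))  # ten items in the "file"

    def __len__(self):
        return len(self.data) - 5  # only the first five should be visible

    def __getitem__(self, index):
        # No bounds check: iteration only stops when the underlying list
        # raises IndexError at index 10, so __len__ is ignored.
        return self.data[index]


class BoundedListDS(ListDS):
    """Same data, but __getitem__ enforces the length reported by __len__."""

    def __getitem__(self, index):
        # Raising IndexError here is what ends the for loop
        # (Python turns it into StopIteration for the caller).
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range for length {len(self)}")
        return self.data[index]


visited_unbounded = [i for i, _ in enumerate(ListDS())]
visited_bounded = [i for i, _ in enumerate(BoundedListDS())]

print(visited_unbounded)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(visited_bounded)    # [0, 1, 2, 3, 4]
```

The same check can be added to __getitem__ in a torch Dataset subclass. Subset is still the less intrusive choice when you need both a training and a validation view of the same loaded file.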