
How does a pytorch dataset object know whether it has hit the end when used in a for loop?


I am writing a custom PyTorch dataset. In __init__ the dataset object loads a file that contains some data, but my program should only access part of it (to get a train/validation split, if that helps). I originally thought this was controlled by overriding __len__, but it turns out that modifying __len__ does not help. A simple example:

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10,2) # suppose there are 10 items in the data file
    
    def __len__(self):
        return len(self.data)-5 # But I only want to access the first 5 items
        
    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)

The output is 0 through 9, while the desired behavior would be 0 through 4.

How does this dataset object know that the enumeration has hit the end when used in a for loop like this? Any other method to achieve a similar effect is also welcome.


Solution

  • You can use torch.utils.data.Subset to get a subset of your data:

    top_five = torch.utils.data.Subset(ds, indices=range(5))  # Get first five items
    for i, x in enumerate(top_five):
        print(i)

    Output:

    0
    1
    2
    3
    4
    

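    The for loop keeps asking the iterator for items until it raises StopIteration. Dataset does not define __iter__, so Python falls back on the old sequence protocol: it calls __getitem__ with 0, 1, 2, ... until an IndexError is raised, and __len__ is never consulted. Here self.data[10] is the first index that raises IndexError, which is why you see 10 items: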

    len(ds)         # Returns the modified length
    5
    
    # The for loop (here via `enumerate`) calls `next` on the iterator at every step.
    # When no more data is available, a StopIteration exception is raised instead.
    iter_ds = iter(ds)
    print(next(iter_ds))
    print(next(iter_ds))
    print(next(iter_ds))
    print(next(iter_ds))
    print(next(iter_ds))
    print(next(iter_ds))
    print(next(iter_ds))
    print(next(iter_ds))
    print(next(iter_ds))
    print(next(iter_ds))
    
    print(next(iter_ds))  # 11th call: StopIteration is raised because no items are left
    

    Output:

    tensor([-1.5952, -0.0826])
    tensor([-2.2254,  0.2461])
    tensor([-0.8268,  0.1956])
    tensor([ 0.3157, -0.3403])
    tensor([0.8971, 1.1255])
    tensor([0.3922, 1.3184])
    tensor([-0.4311, -0.8898])
    tensor([ 0.1128, -0.5708])
    tensor([-0.5403, -0.9036])
    tensor([0.6550, 1.6777])
    
    ---------------------------------------------------------------------------
    StopIteration                             Traceback (most recent call last)
    <ipython-input-99-7a9910e027c3> in <module>
         10 print(next(iter_ds))
         11 
    ---> 12 print(next(iter_ds))  #11th time StopIteration exception raised as no item left to iterate
    
    StopIteration:
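
If you would rather keep a single dataset object than wrap it in Subset, one option is to put the bounds check in __getitem__ itself. The sketch below is pure Python so that it runs without torch; FirstN, limit and enforce_limit are hypothetical names used only for illustration, and the class stands in for NewDS (no __iter__, only __len__ and __getitem__). It compares iteration with and without the check.

```python
class FirstN:
    """Hypothetical stand-in for NewDS: defines __len__ and __getitem__, but no __iter__."""

    def __init__(self, items, limit, enforce_limit):
        self.items = list(items)            # all items "loaded from the file"
        self.limit = limit                  # how many items we want to expose
        self.enforce_limit = enforce_limit  # toggle the bounds check

    def __len__(self):
        return self.limit

    def __getitem__(self, index):
        # Iteration through __getitem__ stops only on IndexError,
        # so the limit has to be enforced here rather than in __len__.
        if self.enforce_limit and not 0 <= index < self.limit:
            raise IndexError(index)
        return self.items[index]


# Without the check, iteration runs until the underlying list raises IndexError.
unchecked = [x for x in FirstN(range(10), limit=5, enforce_limit=False)]
# With the check, iteration stops at the limit reported by __len__.
checked = [x for x in FirstN(range(10), limit=5, enforce_limit=True)]

print(len(unchecked))  # 10
print(len(checked))    # 5
```

The same two-line check (compare index against len(self) and raise IndexError) should carry over to NewDS.__getitem__. Note that this sketch rejects negative indices for simplicity.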