python, pytorch

Is the DataLoader object an iterable object?


This is the code

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math

# creating a custom class for our dataset, which inherits from Dataset.
class WineDataset(Dataset):

    # this function is used for data loading
    def __init__(self):
      # data loading
      xy = np.loadtxt('./wine.csv', delimiter=',', dtype=np.float32, skiprows=1)
      self.x = torch.from_numpy(xy[:, 1:])  # the first column is the output label
      self.y = torch.from_numpy(xy[:, [0]]) # n_samples, 1
      self.n_samples = xy.shape[0]

    # this function allows indexing in our dataset
    def __getitem__(self, index):
      return self.x[index], self.y[index] # the function returns a tuple.

    # this allows us to call len on our dataset.
    def __len__(self):
      return self.n_samples

dataset = WineDataset()
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2)

dataiter = iter(dataloader)
data = next(dataiter)
features, labels = data
print(features, labels)

My question: since we can already call enumerate() directly on the dataloader, does that mean the DataLoader object is an iterable? If so, would calling iter(dataloader) amount to creating an iterator object from an iterator object?

I'm a bit confused about this, so please help me out.

I'd like to know what enumerate() does behind the scenes when the dataloader is passed as an argument, and also what iter(dataloader) does.


Solution

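  • An iterable is an object that implements the __iter__ method. An iterator is an object that implements the __next__ method (the full iterator protocol also expects an iterator to implement __iter__ and return itself from it, which matters further down). Both iter() and enumerate() call the __iter__ method of the object they are given. For example: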

    class A: # this is an iterable
        def __iter__(self):
            print('iter called at A')
            return B()
        
    class B: # this is an iterator
        def __next__(self):
            print('next called at B')
            return 1
    

    Note: any object of class B is an iterator in the loose sense used here, because it implements __next__, but it's not an iterable because it doesn't have an __iter__ method. Similarly, any object of class A is an iterable but not an iterator.

    Run it:

    a = A()
    

    Create an iterator:

    b = iter(a)
    print(f'{type(b)=}')
    
    """
    iter called at A
    type(b)=<class '__main__.B'>
    """
    

    Call next() on the iterator b:

    next(b)
    
    """
    next called at B
    1
    """
    

    We can't call next() on the iterable a:

    next(a)
    
    """
    TypeError: 'A' object is not an iterator
    """
    

    We can do a for loop on a

    for i in a:
        print(i)
        break
    
    """
    iter called at A
    next called at B
    1
    """
    
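    The output above follows from what a for statement does: it calls iter() once on the object, then calls next() on the result until StopIteration is raised. A minimal sketch of that loop, where manual_for is a made-up helper name used only for illustration:

```python
def manual_for(iterable):
    """Collect items in the order `for item in iterable` would visit them."""
    items = []
    iterator = iter(iterable)      # calls iterable.__iter__()
    while True:
        try:
            item = next(iterator)  # calls iterator.__next__()
        except StopIteration:      # signals that the iterator is exhausted
            break
        items.append(item)
    return items


print(manual_for([10, 20, 30]))  # [10, 20, 30]
print(manual_for("ab"))          # ['a', 'b']
```

    Class B above never raises StopIteration, which is why the example loops need a break.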

    Can't do a for loop on b

    for i in b:
        print(i)
        break
    
    """
    TypeError: 'B' object is not iterable
    """
    
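    The error goes away once the iterator also defines __iter__ and returns itself, which is what the iterator protocol expects. A small sketch, using a hypothetical class B2 that is not part of the example above:

```python
from collections.abc import Iterator


class B2:  # hypothetical variant of B that follows the full iterator protocol
    def __iter__(self):
        return self  # an iterator is its own iterator

    def __next__(self):
        return 1


b2 = B2()
for i in b2:  # works now, because B2 defines __iter__
    first = i
    break

print(first)                     # 1
print(iter(b2) is b2)            # True
print(isinstance(b2, Iterator))  # True
```

    The isinstance check passes because collections.abc.Iterator only looks for the presence of __iter__ and __next__.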

    Now, call enumerate

    c = enumerate(a)
    print(f'{type(c)=}')
    
    """
    iter called at A
    type(c)=<class 'enumerate'>
    """
    

    Can do a for loop on c as well as call next()

    next(c)
    
    """
    next called at B
    (0, 1)
    """
    
    for i in c:
        print(i)
        break
    
    """
    next called at B
    (1, 1)
    """
    

    So an enumerate object is both an iterator and an iterable, because it has both __iter__ and __next__ methods. You can check this by calling dir(c).
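
    To make the behaviour of enumerate concrete, here is a rough pure-Python approximation. MyEnumerate is a made-up name, and the real enumerate is implemented in C, so treat this as a sketch of the observable behaviour rather than the actual implementation:

```python
class MyEnumerate:
    """Approximation of enumerate: both an iterable and an iterator."""

    def __init__(self, iterable, start=0):
        # iter() is called right away, matching the 'iter called at A'
        # message printed as soon as enumerate(a) is created.
        self._iterator = iter(iterable)
        self._count = start

    def __iter__(self):
        return self  # it is its own iterator

    def __next__(self):
        value = next(self._iterator)  # StopIteration propagates to the caller
        index = self._count
        self._count += 1
        return index, value


pairs = list(MyEnumerate(["a", "b"]))
print(pairs)  # [(0, 'a'), (1, 'b')]
```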

    When we call enumerate on a DataLoader, its __iter__ method is called. Looking at the signature of the __iter__ method in the PyTorch source code:

    class DataLoader(Generic[T_co]):
    .
    .
        def __iter__(self) -> '_BaseDataLoaderIter':
    

    This _BaseDataLoaderIter class implements both __iter__ and __next__, so it's both an iterable and an iterator.

    class _BaseDataLoaderIter(object):
    .
    .
        def __iter__(self) -> '_BaseDataLoaderIter':
            return self
    
        def __next__(self) -> Any:
        .
        .
            return data
    

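    So you can call both enumerate() and iter() on a DataLoader, and you can loop over it with for. The DataLoader itself defines __iter__ but not __next__, so it is an iterable rather than an iterator: iter(dataloader) creates an iterator from an iterable, not from another iterator. You can check the source code in your Python directory, somewhere at ..\Lib\site-packages\torch\utils\data\dataloader.py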
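
    The same split can be sketched without PyTorch. TinyLoader and TinyLoaderIter below are hypothetical stand-ins that only mirror the iterable/iterator structure described above. They are not PyTorch classes and leave out shuffling, workers and tensors:

```python
class TinyLoaderIter:
    """Iterator: tracks a position and returns itself from __iter__."""

    def __init__(self, data, batch_size):
        self._data = data
        self._batch_size = batch_size
        self._pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._pos >= len(self._data):
            raise StopIteration
        batch = self._data[self._pos:self._pos + self._batch_size]
        self._pos += self._batch_size
        return batch


class TinyLoader:
    """Iterable: has __iter__ only and builds a fresh iterator per call."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        return TinyLoaderIter(self.data, self.batch_size)


loader = TinyLoader(list(range(10)), batch_size=4)

it1 = iter(loader)
it2 = iter(loader)
print(it1 is it2)        # False: each iter() call gives a new iterator
print(iter(it1) is it1)  # True: the iterator returns itself
print(next(it1))         # [0, 1, 2, 3]

batches = list(enumerate(loader))
print(batches)  # [(0, [0, 1, 2, 3]), (1, [4, 5, 6, 7]), (2, [8, 9])]
```

    Because the loader creates a new iterator on every iter() call, it can be looped over once per epoch, while each individual iterator is used up after one pass.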