This is the code:
```python
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math


# creating a custom class for our dataset, which inherits from Dataset.
class WineDataset(Dataset):
    # this function is used for data loading
    def __init__(self):
        # data loading
        xy = np.loadtxt('./wine.csv', delimiter=',', dtype=np.float32, skiprows=1)
        self.x = torch.from_numpy(xy[:, 1:])   # features: every column after the first
        self.y = torch.from_numpy(xy[:, [0]])  # the first column is the class label; shape (n_samples, 1)
        self.n_samples = xy.shape[0]

    # this function allows indexing in our dataset
    def __getitem__(self, index):
        return self.x[index], self.y[index]  # returns a (features, label) tuple

    # this allows us to call len on our dataset
    def __len__(self):
        return self.n_samples


# guard needed when num_workers > 0 on platforms that spawn worker processes (Windows, macOS)
if __name__ == '__main__':
    dataset = WineDataset()
    dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2)
    dataiter = iter(dataloader)
    data = next(dataiter)
    features, labels = data
    print(features, labels)
```
My question is: since we can already call `enumerate` directly on the dataloader, does that mean the dataloader object is an iterable? If so, wouldn't calling `iter(dataloader)` amount to creating an iterator object from an iterator object?

I'm a bit confused about this, please help me out.

I need to know what `enumerate` is doing behind the scenes when the dataloader is passed as an argument, and also what `iter(dataloader)` is doing.
An iterable is something that implements the `__iter__` method. An iterator is something that implements the `__next__` method. Both `iter()` and `enumerate()` call the `__iter__` method of the class. For example:
```python
class A:  # this is an iterable
    def __iter__(self):
        print('iter called at A')
        return B()


class B:  # this is an iterator
    def __next__(self):
        print('next called at B')
        return 1
```
Note that any object of class `B` is an iterator because it implements `__next__`, but it's not an iterable because it doesn't have an `__iter__` method. Similarly, any object of class `A` is an iterable but not an iterator.
Run it:

```python
a = A()
```

Create an iterator:

```python
b = iter(a)
print(f'{type(b)=}')
"""
iter called at A
type(b)=<class '__main__.B'>
"""
```
Calling `next()` on iterator `b`:

```python
next(b)
"""
next called at B
1
"""
```
We can't call `next()` on iterable `a`:

```python
next(a)
"""
TypeError: 'A' object is not an iterator
"""
```
We can do a `for` loop on `a`:

```python
for i in a:
    print(i)
    break
"""
iter called at A
next called at B
1
"""
```
We can't do a `for` loop on `b`:

```python
for i in b:
    print(i)
    break
"""
TypeError: 'B' object is not iterable
"""
```
Now, call `enumerate`:

```python
c = enumerate(a)
print(f'{type(c)=}')
"""
iter called at A
type(c)=<class 'enumerate'>
"""
```
We can do a `for` loop on `c` as well as call `next()` on it:

```python
next(c)
"""
next called at B
(0, 1)
"""

for i in c:
    print(i)
    break
"""
next called at B
(1, 1)
"""
```
So the `enumerate` class is both an iterator and an iterable, because it has both `__iter__` and `__next__` methods. You can check this by calling `dir(c)`.
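Besides `dir(c)`, the abstract base classes in the standard `collections.abc` module can test this directly: `isinstance(x, Iterable)` looks for `__iter__`, and `isinstance(x, Iterator)` looks for both `__iter__` and `__next__`. A small sketch on built-in types:

```python
from collections.abc import Iterable, Iterator

letters = ['a', 'b']
pairs = enumerate(letters)

# A list is iterable but not an iterator (it has __iter__ but no __next__)
print(isinstance(letters, Iterable), isinstance(letters, Iterator))  # True False

# An enumerate object is both (it has __iter__ and __next__)
print(isinstance(pairs, Iterable), isinstance(pairs, Iterator))      # True True

# Calling iter() on an iterator hands back the very same object
print(iter(pairs) is pairs)  # True
```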
When we call `enumerate` on a `DataLoader`, its `__iter__` method is called. Here is the signature of `__iter__` in the PyTorch source code:
```python
class DataLoader(Generic[T_co]):
    ...
    def __iter__(self) -> '_BaseDataLoaderIter':
        ...
```
This `_BaseDataLoaderIter` class implements both `__iter__` and `__next__`, so it's both an iterable and an iterator.
```python
class _BaseDataLoaderIter(object):
    ...
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def __next__(self) -> Any:
        ...
        return data
```
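The same split can be mimicked without PyTorch. In the sketch below, `ToyLoader` and `_ToyLoaderIter` are hypothetical stand-ins, not PyTorch classes, and they leave out shuffling, workers and collation. The loader only defines `__iter__` and builds a fresh iterator object on every call, while the iterator returns itself from `__iter__`, yields batches from `__next__` and raises `StopIteration` when the data runs out:

```python
class _ToyLoaderIter:
    """Hypothetical batch iterator, loosely modelled on _BaseDataLoaderIter."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.pos = 0

    def __iter__(self):
        return self  # an iterator is its own iterator

    def __next__(self):
        if self.pos >= len(self.data):
            raise StopIteration  # no batches left
        batch = self.data[self.pos:self.pos + self.batch_size]
        self.pos += self.batch_size
        return batch


class ToyLoader:
    """Hypothetical iterable, loosely modelled on DataLoader (it has no __next__)."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        # A new iterator on every call, so every loop starts from the beginning
        return _ToyLoaderIter(self.data, self.batch_size)


loader = ToyLoader(list(range(10)), batch_size=4)
for i, batch in enumerate(loader):
    print(i, batch)
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
# 2 [8, 9]

first = iter(loader)
print(next(first))            # [0, 1, 2, 3]
print(iter(first) is first)   # True
print(iter(loader) is first)  # False: the loader built a new iterator
```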
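So you can call both `enumerate()` and `iter()` on a `DataLoader`, and you can use it in `for` loops. Note that `DataLoader` itself has no `__next__`: it is an iterable, not an iterator, so `iter(dataloader)` creates an iterator from an iterable rather than from another iterator. You can check the source code in your Python installation, somewhere at
`..\Lib\site-packages\torch\utils\data\dataloader.py`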
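As for what `enumerate()` does behind the scenes, a rough pure-Python approximation is sketched below as a hypothetical `MyEnumerate` class (CPython's real `enumerate` is written in C). It calls `iter()` once, immediately in the constructor, which matches the `iter called at A` line printed by `enumerate(a)` above and, for a `DataLoader`, corresponds to one call to `DataLoader.__iter__`. After that it only calls `next()` on the stored iterator and pairs each item with a counter:

```python
class MyEnumerate:
    """Hypothetical pure-Python approximation of the built-in enumerate."""

    def __init__(self, iterable, start=0):
        self._it = iter(iterable)  # eager, like the built-in: calls iterable.__iter__ now
        self._count = start

    def __iter__(self):
        return self  # both an iterable and an iterator, like enumerate

    def __next__(self):
        item = next(self._it)  # StopIteration from the inner iterator ends the loop too
        index = self._count
        self._count += 1
        return index, item


pairs = MyEnumerate(['a', 'b', 'c'], start=1)
print(next(pairs))  # (1, 'a')
print(list(pairs))  # [(2, 'b'), (3, 'c')]
```

The Python documentation shows a generator-based equivalent instead; a generator would behave the same in a `for` loop but would defer the `iter()` call until the first `next()`.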