Tags: python, pytorch, iterator, dataloader

What do next() and iter() do in PyTorch's DataLoader()?


I have the following code:

import torch
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader

# Load dataset
df = pd.read_csv(r'../iris.csv')

# Extract features and target
data = df.drop('target',axis=1).values
labels = df['target'].values

# Create tensor dataset
iris = TensorDataset(torch.FloatTensor(data),torch.LongTensor(labels))

# Create random batches
iris_loader = DataLoader(iris, batch_size=105, shuffle=True)

next(iter(iris_loader))

What do next() and iter() do in the above code? I have gone through PyTorch's documentation and still can't quite understand what next() and iter() are doing here. Can anyone help explain this? Many thanks in advance.


Solution

  • These are built-in Python functions for working with iterables and iterators.

    Basically, iter() calls the __iter__() method on iris_loader, which returns an iterator. next() then calls the __next__() method on that iterator to get its first item, which here is the first batch of features and labels. Calling next() on the same iterator again returns the second batch, and so on.
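
To make this concrete, here is a minimal sketch that uses a small synthetic dataset in place of iris.csv. The names features, targets, loader and batch_iter are made up for illustration, and shuffle=False is an assumption chosen only so that the output is reproducible.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Synthetic stand-in for the iris data (assumption): 6 samples, 2 features each.
features = torch.arange(12, dtype=torch.float32).reshape(6, 2)
targets = torch.tensor([0, 1, 2, 0, 1, 2])
loader = DataLoader(TensorDataset(features, targets), batch_size=4, shuffle=False)

batch_iter = iter(loader)               # calls loader.__iter__(), returns an iterator
first_x, first_y = next(batch_iter)     # calls batch_iter.__next__(): first batch
second_x, second_y = next(batch_iter)   # second batch (smaller, since only 2 samples remain)

print(first_x.shape, first_y.shape)     # torch.Size([4, 2]) torch.Size([4])
print(second_x.shape, second_y.shape)   # torch.Size([2, 2]) torch.Size([2])

# next(iter(loader)) builds a new iterator on every call, so it starts over each time.
again_x, _ = next(iter(loader))
print(torch.equal(again_x, first_x))    # True, because shuffle=False here
```

Note the last step: writing next(iter(iris_loader)) repeatedly does not walk through the dataset. Each call creates a new iterator and returns its first batch; with shuffle=True that batch is simply a different random one each time.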

    This logic often happens behind the scenes, for example in a for loop. The loop calls __iter__() on the iterable, then repeatedly calls __next__() on the returned iterator until the iterator is exhausted. At that point the iterator raises StopIteration, which the loop catches, and the loop ends.
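
The equivalence can be sketched in plain Python. The list batches below is a hypothetical stand-in for any iterable, such as a DataLoader, and the while loop is only an approximation of what a for loop does internally.

```python
# Stand-in for any iterable, e.g. a DataLoader (assumption for illustration).
batches = ["batch 0", "batch 1", "batch 2"]

# What you normally write:
seen_by_for = []
for batch in batches:
    seen_by_for.append(batch)

# Roughly what Python does for you:
seen_by_hand = []
iterator = iter(batches)          # calls batches.__iter__()
while True:
    try:
        batch = next(iterator)    # calls iterator.__next__()
    except StopIteration:         # raised by the iterator once it is exhausted
        break                     # a for loop catches this and exits quietly
    seen_by_hand.append(batch)

print(seen_by_for == seen_by_hand)  # True
```

Both versions visit the same items in the same order; the for loop just hides the iter(), next() and StopIteration handling.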

    Please see the documentation for further details and some nuances: https://docs.python.org/3/library/functions.html#iter