
How can I access the next step data using DataLoader in PyTorch?

I am using code that trains a neural network. It uses PyTorch's DataLoader to load the data for every iteration. The code looks as follows:

for step, data in enumerate(dataloader, 0):
      output = neuralnetwork_model(data)

Here step is an integer that takes the values 0, 1, 2, 3, ..., and data is a batch of samples at each step. At each step, the code passes the corresponding batch to the neural network.

At step n, I also need to access the data of step n+1. I need something like this:

for step, data in enumerate(dataloader, 0):
      output = neuralnetwork_model(data)
      access = data_of_next_step

How can I achieve this?


  • It is probably simpler to do this at the iteration level than to change the data loader's implementation. Following the recipe in Iterate over n successive elements with overlap, you can achieve this with itertools.tee:

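    from itertools import tee

    def pairwise(iterable):
        "s -> (s0,s1), (s1,s2), (s2, s3), ..."
        a, b = tee(iterable)  # two independent iterators over the same data
        next(b, None)         # advance b so it runs one element ahead of a
        return zip(a, b)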

    (On Python 3.10 and later, itertools.pairwise provides the same function.) You then simply iterate over the wrapped data loader:

    >>> for batch1, batch2 in pairwise(dataloader):
    ...     # batch1 is the current batch
    ...     # batch2 is the batch of the following step
    ...     output = neuralnetwork_model(batch1)
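One consequence of pairwise is that the final batch only ever appears as batch2, so the loop runs one step fewer than iterating over dataloader directly. If the last batch must also be trained on, a minimal sketch of a variant pads the look-ahead with None using itertools.zip_longest. The name with_next is a hypothetical helper, not part of PyTorch or itertools, and a plain list stands in for the DataLoader; the same code should work on any iterable of batches. Because tee shares one underlying iterator, each batch is still loaded only once, and wrapping the result in enumerate keeps the step counter from the question.

```python
from itertools import tee, zip_longest


def with_next(iterable):
    """Yield (current, next) pairs; next is None for the final item.

    Hypothetical helper: unlike pairwise, the last item is still
    yielded as 'current', paired with None as its look-ahead.
    """
    current_it, next_it = tee(iterable)
    next(next_it, None)  # shift the look-ahead copy by one element
    # next_it is one element shorter, so zip_longest pads it with None
    return zip_longest(current_it, next_it, fillvalue=None)


# Stand-in for a DataLoader: a list of three "batches" (assumption).
batches = [[0, 1], [2, 3], [4, 5]]

for step, (data, next_data) in enumerate(with_next(batches)):
    # data is the batch of this step; next_data is the batch of step + 1
    # (None on the last step)
    print(step, data, next_data)
```

With a real DataLoader you would pass the loader itself, e.g. enumerate(with_next(dataloader)), and check next_data is None before using it on the final step.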