Why does next(iter(train_dataloader)) take so long to execute in PyTorch?


I am trying to load a local dataset with images (around 225 images in total) using the following code:

# Set the batch size
BATCH_SIZE = 32 

# Create data loaders
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
  train_dir=train_dir,
  test_dir=test_dir,
  transform=manual_transforms, # use manually created transforms
  batch_size=BATCH_SIZE
)

# Get a batch of images
# Why does this take so much time, and what can I do about it?
image_batch, label_batch = next(iter(train_dataloader))

My question is about the last line: iterating over train_dataloader takes a long time. Why is that, given that I have only 225 images?

Edit:

The code for data_setup.create_dataloaders is below:

import os

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# One worker process per CPU core; each worker is a separate Python process.
NUM_WORKERS = os.cpu_count()

def create_dataloaders(
    train_dir: str,
    test_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    num_workers: int = NUM_WORKERS
):
    # Use ImageFolder to create dataset(s)
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    # Get class names
    class_names = train_data.classes

    # Turn images into data loaders
    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,  # don't need to shuffle test data
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_dataloader, test_dataloader, class_names
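For scale, here is a back-of-the-envelope sketch (not part of the original post) comparing the question's numbers with NUM_WORKERS = os.cpu_count(): 225 images with BATCH_SIZE = 32 give only ceil(225 / 32) = 8 batches per epoch, yet iter(...) starts every one of the num_workers processes. batches_per_epoch and idle_workers are hypothetical helpers, and the model assumes the DataLoader default drop_last=False and that batches are handed out one per worker in turn.

```python
import math
import os


def batches_per_epoch(n_samples: int, batch_size: int) -> int:
    # With drop_last=False (the DataLoader default) the last, smaller batch is kept.
    return math.ceil(n_samples / batch_size)


def idle_workers(n_samples: int, batch_size: int, num_workers: int) -> int:
    # Hypothetical model: workers started by iter(...) that receive no batch
    # in an epoch because there are fewer batches than workers.
    return max(0, num_workers - batches_per_epoch(n_samples, batch_size))


cpus = os.cpu_count() or 1  # os.cpu_count() may return None
print("batches per epoch:", batches_per_epoch(225, 32))
print("workers started:", cpus, "idle:", idle_workers(225, 32, cpus))
```

On a 16-core machine this model gives 8 batches and 8 workers that are started (and paid for) without loading anything in that epoch.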

Solution

  • The main reason the next(iter(train_dataloader)) call is slow is multiprocessing, or rather its pitfalls. When num_workers > 0, the call to iter(train_dataloader) starts num_workers worker processes. With the spawn start method (the default on Windows and macOS), each worker re-imports the main Python script (the current script), so any time-consuming code that runs at import time before the call to iter(...), such as any kind of file loading in global scope (!), causes an extra slowdown. That slowdown comes on top of the process creation time and on top of the serialization and deserialization of data that has to happen when next(iter(...)) is called.

    You can verify this by adding time.sleep(5) in global scope anywhere before the call to next(iter(train_dataloader)): the call then becomes about 5 s slower than it already was.

    Unfortunately, I don't know how to fix this for the torch DataLoader, apart from (1) setting num_workers=0, (2) making sure there is no time-consuming code that runs when the main script is imported, or (3) not using the torch DataLoader and using the HuggingFace dataset interfaces instead.

    Update: There does not seem to be a workaround. If you have the following code in the same script:

    dataloader = create_dataloaders(...)  # similar to the OP's code
    for x in dataloader:
         ...
    

    or if you initialize the dataloader in some other module and use something like

    from other_module import dataloader
    a, b = next(iter(dataloader))
    

    then starting the worker processes (which happens when iteration begins) will, under the spawn start method, cause re-initialization of the dataloader and its underlying datasets, which read everything from disk again. So it appears that it only makes sense to use num_workers=1 (or higher) if data actually needs to be downloaded from remote servers. If all data is already on the local machine, then, as I understand it, it never makes sense to set num_workers=1 (or higher) with this API. (I'm not entirely sure here, since I'm not familiar with the underlying torch implementation. Conceivably it could also make sense when the transform is much slower than the serialization/deserialization part of the code.)
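The verification step and fix (2) from the answer can be sketched together in one script. This is a minimal sketch, not the original poster's code: make_dataset, make_loader, time_first_batch and main are hypothetical names, and a TensorDataset of random tensors stands in for the ImageFolder datasets so it runs without files on disk. All dataset construction and iteration sits inside main() behind the standard if __name__ == "__main__": guard, so a worker that re-imports the script under spawn only re-executes the cheap definitions. It also passes DataLoader's persistent_workers=True when num_workers > 0, which keeps workers alive across epochs (relevant to the for x in dataloader loop above) but does not remove the cost of the first iter(...).

```python
import multiprocessing
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_dataset(n_samples: int = 225) -> TensorDataset:
    # Hypothetical stand-in for ImageFolder: random 3x64x64 "images" with
    # integer labels, so the sketch runs without any files on disk.
    images = torch.rand(n_samples, 3, 64, 64)
    labels = torch.randint(0, 3, (n_samples,))
    return TensorDataset(images, labels)


def make_loader(dataset, batch_size: int = 32, num_workers: int = 0) -> DataLoader:
    # persistent_workers keeps worker processes alive between epochs and is
    # only allowed when num_workers > 0. It does not make the first
    # iter(...) cheaper, it only avoids restarting workers every epoch.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )


def time_first_batch(loader: DataLoader):
    # Times iter(...) (which starts any workers) plus the first next(...).
    start = time.perf_counter()
    images, labels = next(iter(loader))
    elapsed = time.perf_counter() - start
    return images, labels, elapsed


def main() -> None:
    # Expensive work lives here, not at module level, so a worker that
    # re-imports this script (as under spawn) skips it.
    print("start method:", multiprocessing.get_start_method())
    dataset = make_dataset()
    for workers in (0, 2):
        loader = make_loader(dataset, num_workers=workers)
        _, _, elapsed = time_first_batch(loader)
        print(f"num_workers={workers}: first batch after {elapsed:.3f} s")


if __name__ == "__main__":
    main()
```

To reproduce the time.sleep(5) experiment, put time.sleep(5) at module level above main(): under spawn the num_workers=2 timing should grow by roughly 5 s while the num_workers=0 timing stays the same. Under the fork start method the module-level sleep is not re-run in the workers.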