Search code examples
pythonpytorchpytorch-dataloader

PyTorch model training with DataLoader is too slow


I'm training a very small NN using the HAM10000 dataset. For loading the data I'm using the DataLoader that ships with PyTorch:

class CocoDetectionWithFilenames(CocoDetection):
    def __init__(self, root: str, ann_file: str, transform=None):
        super().__init__(root, ann_file, transform)

    def get_filename(self, idx: int) -> str:
        return self.coco.loadImgs(self.ids[idx])[0]["file_name"]


def get_loaders(root: str, ann_file: str) -> tuple[CocoDetection, DataLoader, DataLoader, DataLoader]:
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset = CocoDetectionWithFilenames(
        root=root,
        ann_file=ann_file,
        transform=transform
    )
    train_size = int(0.7 * len(dataset))
    valid_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - valid_size
    train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size, test_size])
    num_workers = os.cpu_count()
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=1024
    )
    valid_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=1024
    )
    test_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return dataset, train_loader, valid_loader, test_loader

The thing is, when my training loop runs, the training itself is very fast, but the program spends 95% on the time inbetween epochs - probably loading the data:

def extract_bboxes(targets: list[dict]) -> list[torch.Tensor]:
    bboxes = []

    for target in targets:
        xs, ys, widths, heights = target["bbox"]

        for idx, _ in enumerate(xs):
            x1, y1, width, height = xs[idx], ys[idx], widths[idx], heights[idx]
            # Convert COCO format (x, y, width, height) to (x1, y1, x2, y2)
            x2, y2 = x1 + width, y1 + height

            bboxes.append(torch.IntTensor([x1, y1, x2, y2]))

    return bboxes

num_epochs = 25
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, targets in train_loader_tqdm:
        images = images.to(device)
        bboxes = extract_bboxes(targets)
        bboxes = torch.stack(bboxes).to(device)

        optimizer.zero_grad(set_to_none=True)

        outputs = model(images)
        loss = criterion(outputs, bboxes)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_train_loss = running_loss / len(train_loader)

    train_losses.append(epoch_train_loss)
    print(f"Epoch {epoch + 1}, Loss: {epoch_train_loss}")
    model.eval()

As you can see, the training loop code is quite simple, nothing weird happening there.


Solution

  • Try to reduce num_workers and prefetch_factor. It may spend all the time fetching that 1024 batches using all threads.