Search code examples
pythonpytorchconv-neural-networktraining-databatchsize

expected batch_size() is not the same as target batch_size() pytorch


I wrote a code to segment a satellite-picture into seven regions (city, forest, water, ...). The problem is that when I execute the script I get exact the following error:

Traceback (most recent call last): File "/Users/.../pytorch_test.py", line 215, in model = train_model(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, num_epochs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/.../pytorch_test.py", line 179, in train_model train_loss = train(model, train_dataloader, loss_fn, optimizer, device) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/.../pytorch_test.py", line 134, in train loss = loss_fn(outputs, labels) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/loss.py", line 1174, in forward return F.cross_entropy(input, target, weight=self.weight, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/functional.py", line 3029, in cross_entropy return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ValueError: Expected input batch_size (3) to match target batch_size (9).

The problem is that when I change the batch_size I get every time tripple the target batch_size, but I cannot find the bug. I felt I searched the whole internet, but found nothing. I hope some of you can help me!

Thats my code:

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from PIL import ImageEnhance

class SatelliteDataset(Dataset):
    def __init__(self, image_folder, label_folder, transform=None):
        self.image_folder = image_folder
        self.label_folder = label_folder
        self.transform = transform

        self.image_paths = sorted([os.path.join(image_folder, filename) for filename in os.listdir(image_folder)])
        self.label_paths = sorted([os.path.join(label_folder, filename) for filename in os.listdir(label_folder)])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label_path = self.label_paths[idx]

        image = Image.open(image_path)
        label = Image.open(label_path)

        image = self.adjust_brightness(image, brightness_factor=1.8)

        # Resize images to 512x512
        image = image.resize((512, 512), Image.BILINEAR)
        label = label.resize((512, 512), Image.NEAREST)

        # Apply transformations if specified
        if self.transform:
            image = self.transform(image)
            label = self.transform(label)

        return image, label

    def adjust_brightness(self, image, brightness_factor=1.0):
        enhancer = ImageEnhance.Brightness(image)
        enhanced_image = enhancer.enhance(brightness_factor)
        return enhanced_image


# Define paths to the folders containing satellite images and corresponding labels
image_folder = "/Users/.../train_data/images"
label_folder = "/Users/.../train_data/masks"

# Define transformations for normalization and scaling
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Create an instance of the SatelliteDataset
dataset = SatelliteDataset(image_folder, label_folder, transform=transform)

# Create DataLoader for training
def collate_fn(batch):
    images, labels = zip(*batch)
    images = torch.stack(images)
    labels = torch.stack(labels)
    return images, labels

batch_size = 3
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# cnn
class CNN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()

        self.conv5 = nn.Conv2d(512, out_channels=7, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.conv3(x)
        x = self.relu3(x)
        x = self.maxpool3(x)

        x = self.conv4(x)
        x = self.relu4(x)

        x = self.conv5(x)

        x = torch.softmax(x, dim=1)  # add Softmax layer
        return x


# Training loop
def train(model, dataloader, loss_fn, optimizer, device):
    model.train()  # Set the model to training mode

    running_loss = 0.0

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)

        # Reshape labels dimensions
        labels = labels.view(-1, 512, 512)
        labels = labels.long()

        # Calculate loss
        loss = loss_fn(outputs, labels)

        # Backward pass and weight update
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)

    return epoch_loss


# Validation
def evaluate(model, dataloader, loss_fn, device):
    model.eval()  # Set the model to evaluation mode

    running_loss = 0.0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)

            # Calculate loss
            loss = loss_fn(outputs, labels)

            running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss


# Perform training
def train_model(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, num_epochs):
    best_val_loss = float('inf')
    best_model_weights = None
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Training step
        train_loss = train(model, train_dataloader, loss_fn, optimizer, device)
        print(f"Train Loss: {train_loss}")

        # Validation step
        val_loss = evaluate(model, val_dataloader, loss_fn, device)
        print(f"Val Loss: {val_loss}")

        # Check for improvement in validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_weights = model.state_dict()

    # Return the best model
    model.load_state_dict(best_model_weights)
    return model



# Example usage of the model
model = CNN(in_channels=3, out_channels=7)

# Select device (CPU)
device = torch.device("cpu")

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

Number of epochs (adjust as needed)
num_epochs = 10

# Create a separate DataLoader for validation with batch_size=3
val_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Perform training
model = train_model(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, num_epochs)

By the way I am programming on MacOS and with VSC

I tried: I copied all the error messages into google and read almost every article on that problem. Nonetheless I could not resolve this. I tried to change the batch sizes (I have 804 pictures to train the model with) and changed the data-loader. I also asked ChatGPT, although it explained the problem very good and gave a couple of ideas to solve this, it helped not.


Solution

  • You have a line labels = labels.view(-1, 512, 512) that is changing the dimensions of your labels. From your dataloader, it seems that your labels are images. If it had 3 channels (RGB), then this dimension-changing line would basically triple your batch size.