Search code examples
pythonmachine-learningconv-neural-networkartificial-intelligencemedical-imaging

How to Fix Slight Mismatch Between Dimensions of CNN Output Data and the Target?


I am trying to create a version of the UNet CNN which will take in a certain type of MRI image volume as the source and use corresponding MRI image volume as the target.

After quite a bit of trial and error I am still getting a small mismatch between the size of the CNN's output and the dimensions of the target. The CNN output is 208x224x160, but the source/target data are both 210x224x160. This causes a runtime error during the calculation of the loss. What's strange is that the dimension mismatch doesn't occur when I put in randomly generated data, the output has the same dimensions as the input.

What could be causing this error and how should I go about fixing it?

Here is the code:

import nibabel as nib
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

# function using nibabel to load a single volume from the disk
def load_Volume(filepath):
    img = nib.load(filepath)
    data = img.get_fdata()
    return data

def preprocess_mri_data(data):
    # Normalize the data, other pre-processing can be added
    mean = np.mean(data)
    std = np.std(data)
    data = (data - mean) / std
    return data


# Dataset class to use with the data loader. Pairs sources with targets.
class MRISource_Target(Dataset):
    def __init__(self, source_dir, target_dir, transform=None):
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.source_filenames = os.listdir(source_dir)
        self.target_filenames = os.listdir(target_dir)
        self.transform = transform

    def __len__(self):
        return len(self.source_filenames)

    def __getitem__(self, idx):
        source_filepath = os.path.join(self.source_dir, self.source_filenames[idx])
        target_filepath = os.path.join(self.target_dir, self.target_filenames[idx])

        source_data = load_Volume(source_filepath)
        target_data = load_Volume(target_filepath)

        source_data = preprocess_mri_data(source_data)
        target_data = preprocess_mri_data(target_data)

        if self.transform:
            source_data = self.transform(source_data)
            target_data = self.transform(target_data)

        return {'source': source_data, 'target': target_data}

# directories for the training and testing data
train_source_dir = '/content/drive/MyDrive/qsmData/Train/Source'
train_target_dir = '/content/drive/MyDrive/qsmData/Train/Target/'
test_source_dir = '/content/drive/MyDrive/qsmData/Test/Source/'
test_target_dir = '/content/drive/MyDrive/qsmData/Test/Target/'

# create the paired datasets
train_dataset = MRISource_Target(train_source_dir, train_target_dir)
test_dataset = MRISource_Target(test_source_dir, test_target_dir)

# make the datasets iteratable for training
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# visualize an arbitrary slice
def plot_mri_slice(volume, slice_num):
    plt.imshow(volume[:, :, slice_num], cmap='gray')
    plt.axis('off')
    plt.show()

import torch
import torch.nn as nn

# Define the U-Net architecture
class UNet(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(UNet, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv3d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )

        self.middle = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, output_channels, kernel_size=3,padding=1),
            #nn.Tanh()  # Assuming magnetic susceptibility values are in a specific range
        )

    def forward(self, x):
        x1 = self.encoder(x)
        x2 = self.middle(x1)
        x3 = self.decoder(x2)
        return x3

# Example usage:
batch_size = 1
input_channels = 1  # Number of input channels (MRI phase)
output_channels = 1  # Number of output channels (Magnetic susceptibility)
depth = 64  # Updated depth to match cropped data
height = 64
width = 64

# Create the U-Net model
generator = UNet(input_channels, output_channels)

# Example input data
input_data = torch.randn(batch_size, input_channels, depth, height, width)

# Generate output
output = generator(input_data)

# Print the generated output shape
print("Generated Output Shape:", output.shape)
import nibabel as nib

def get_data_dimensions(filepath):
    img = nib.load(filepath)
    data = img.get_fdata()
    return data.shape

source_filepath = '/content/drive/MyDrive/qsmData/Train/Source/normPhaseSubj1.nii'
target_filepath = '/content/drive/MyDrive/qsmData/Train/Target/cosmos1.nii.gz'


source_dimensions = get_data_dimensions(source_filepath)
target_dimensions = get_data_dimensions(target_filepath)

print("Source data dimensions:", source_dimensions)
print("Target data dimensions:", target_dimensions)


# Define the loss function and optimizer
criterion = nn.MSELoss(reduce=None)
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)

# Move the model to the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
num_epochs = 5
print_interval = 10
for epoch in range(num_epochs):
    generator.train()
    running_loss = 0.0
    for i, batch in enumerate(train_loader, 1):  # Enumerate to track batch index
        source_data = batch['source'].to(device).unsqueeze(1).float()  # Add the channel dimension
        target_data = batch['target'].to(device).unsqueeze(1).float()  # Add the channel dimension

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = generator(source_data)

        print(outputs.shape)

        print("Target shape:", target_data.shape)

        # Compute loss
        loss = criterion(outputs, target_data)

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Print average loss for the epoch
        if i % print_interval == 0:
            avg_loss = running_loss / print_interval
            print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{i}/{len(train_loader)}], Loss: {avg_loss:.4f}')
            running_loss = 0.0

predictions = []
generator.eval()  # Set the model to evaluation mode
with torch.no_grad():
    for batch in test_loader:
        source_patches = batch['source'].to(device).unsqueeze(1).float()  # Add the channel dimension

        # Forward pass and get the predictions
        outputs = generator(source_patches)

        # Store the predictions in the list
        predictions.append(outputs.cpu().squeeze().numpy())

I tried making a simpler architecture and still got dimension errors, in fact they were even larger. When I wasn't getting dimension errors I would just get out of memory errors. I also have tried verifying the dimensions of the data throughout different stages of the program, and even though the randomly generated data doesn't have a mismatch between input and output, my MRI data still does once it's put through the network.


Solution

  • I was able to fix the dimension errors by applying max pooling after the middle layer and the decoder.
    I am not sure why this works, but now output and input sizes are consistent.
    The prediction results look bad right now, but I'm pretty sure that's because I set the number of channels down to between 2 to 8, and only trained for 1 epoch.
    It will be interesting to see how this architecture works as I apply normal hyper-parameters. I'm just glad that there are no more runtime errors or out of memory issues.