I am training a convolutional network in PyTorch on 3D medical raster images (.nrrd files) to estimate volume measurements from very noisy ultrasound images.
I have around 200 individual raster images from 30 patients, and have augmented them to over 5000 by applying all kinds of transforms and noise along all 3 axes (chosen randomly). All the rasters are resized to 128x128x128 before being used.
I am doing 6-fold cross validation, where I make sure that the validation set is composed of entirely different patients from those in the training set. I think this helps see if the model is actually generalizing and is capable of estimating rasters of unseen patients.
Problem is, the model is failing to generalize or learn at all. See the results I get for 2 test runs I have made (10 hours processing each):
The architecture used is just 6 convolutional layers followed by 2 densely connected ones, nothing too fancy. What could be causing this? Could it be I don't have enough data for my model to learn?
I tried lowering the learning rate and raising weight decay, with no luck. I haven't tried other loss functions or optimizers (currently using MSE loss and Adam).
Edit: added code:
    class RasterNet(nn.Module):
        def __init__(self):
            super(RasterNet, self).__init__()
            self.conv0 = nn.Sequential(  # 128x128x128 -> 256x32x32
                nn.Conv2d(128, 256, kernel_size=7, stride=2, padding=3),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.conv1 = nn.Sequential(  # 256x32x32 -> 512x16x16
                nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.conv2 = nn.Sequential(  # 512x16x16 -> 1024x8x8
                nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(1024),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.conv3 = nn.Sequential(  # 1024x8x8 -> 2048x4x4
                nn.Conv2d(1024, 2048, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(2048),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.conv4 = nn.Sequential(  # 2048x4x4 -> 4096x2x2
                nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(4096),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.conv5 = nn.Sequential(  # 4096x2x2 -> 8192x1x1
                nn.Conv2d(4096, 8192, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(8192),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.linear = nn.Sequential(
                nn.Linear(8192, 4096),
                nn.ReLU(),
                nn.Linear(4096, 1)
            )

        def forward(self, base):
            base = base.squeeze().float().to(dml)
            # View from y axis (Coronal, as this is the clearest view)
            base = torch.transpose(base, 2, 1)
            x = self.conv0(base)
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
            x = self.conv5(x)
            x = x.view(x.size(0), -1)
            return self.linear(x)
OK, a few notes that are not an "answer" per se but are too long for comments:
First, the fact that your training loss converges to a low value, but your validation loss is high, means that your model is overfit to the training distribution. This could mean:
Second, an almost-ubiquitous trick for NN training is runtime data augmentation. This means that, rather than generating a set of augmented images before training, you define a set of augmenting functions that apply data transformations randomly. These functions transform each batch at every training epoch, so the model never sees exactly the same example twice.
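To make the runtime augmentation idea concrete, here is a minimal sketch. The names `random_augment` and `AugmentedVolumes` are made up for illustration, the noise scale is an arbitrary placeholder, and I am assuming each raster is already loaded as a cubic `(D, H, W)` float tensor. Only flips, 90-degree rotations and additive noise are used, since those should not change the volume being estimated.

```python
import torch
from torch.utils.data import Dataset


def random_augment(volume: torch.Tensor) -> torch.Tensor:
    """Apply a fresh random flip, 90-degree rotation and noise to a (D, H, W) volume."""
    # Flip each spatial axis independently with probability 0.5.
    for axis in range(3):
        if torch.rand(1).item() < 0.5:
            volume = torch.flip(volume, dims=[axis])
    # Rotate by a random multiple of 90 degrees in the H-W plane
    # (shape-preserving only because H == W is assumed).
    k = int(torch.randint(0, 4, (1,)).item())
    volume = torch.rot90(volume, k, dims=(1, 2))
    # Add mild Gaussian noise; the 0.05 scale is an arbitrary placeholder.
    return volume + 0.05 * torch.randn_like(volume)


class AugmentedVolumes(Dataset):
    """Wraps the original (un-augmented) volumes and augments on every access."""

    def __init__(self, volumes, targets, augment=None):
        self.volumes = volumes
        self.targets = targets
        self.augment = augment  # pass None for the validation set

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        volume = self.volumes[idx]
        if self.augment is not None:
            volume = self.augment(volume)
        return volume, self.targets[idx]
```

Wrapping the roughly 200 original rasters in such a dataset and feeding it to a `DataLoader` would give a different random variant of each raster every epoch, while the validation set stays un-augmented.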
Third, this model architecture is relatively simplistic (simpler than AlexNet, the first modern deep CNN). Far better performance has been achieved with much deeper architectures that use residual layers (see ResNet) to deal with the vanishing gradient problem. I'd be somewhat surprised if you could achieve good performance on this task with this architecture.
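For reference, a residual layer in the ResNet sense looks roughly like the sketch below. This is a simplified basic block (same channel count in and out, no downsampling), written with `Conv2d` only to match the code in the question; `ResidualBlock` is an illustrative name, not something from a library.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Two conv layers whose output is added back onto the block's input."""

    def __init__(self, channels: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # The skip connection (x + ...) gives gradients a direct path
        # around the convolutions, which is what helps in deep stacks.
        return self.relu(x + self.body(x))
```

Because the block preserves the input shape, several of them can be stacked between the downsampling stages of a network without touching the rest of the architecture.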
It is normal for the validation loss to be higher on average than the training loss. It is possible that your model is learning to some extent but the validation loss curve is relatively shallow compared to the (likely overfit) training curve. I suggest also computing an epoch-wide validation metric and reporting it across epochs (for a regression target like volume, an error measure such as mean absolute error takes the place of accuracy). You should see the training metric improve, and possibly the validation metric as well.
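As a rough sketch of an epoch-wide validation metric: because the target is a continuous volume, mean absolute error is assumed here in place of classification accuracy. The helper name `epoch_mae` is hypothetical, and the loader is assumed to yield `(inputs, targets)` batches with one target value per sample.

```python
import torch


@torch.no_grad()
def epoch_mae(model, loader, device="cpu"):
    """Mean absolute error over every sample in `loader`, with the model in eval mode."""
    was_training = model.training
    model.eval()  # use running BatchNorm statistics, not batch statistics
    total_abs_error, n_samples = 0.0, 0
    for inputs, targets in loader:
        preds = model(inputs.to(device)).reshape(-1)
        targets = targets.to(device).float().reshape(-1)
        total_abs_error += (preds - targets).abs().sum().item()
        n_samples += targets.numel()
    model.train(was_training)  # restore whatever mode the model was in
    return total_abs_error / max(n_samples, 1)
```

Calling this once per epoch on both the training and validation loaders gives two curves in the same units as the volume, which are usually easier to interpret than raw MSE.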
Do note that cross-validation is not exactly meant to determine whether the model generalizes to unseen patients. That is the purpose of the validation set. Instead, cross-validation ensures that the training/validation performance holds across multiple data partitions, and isn't simply the result of selecting an "easy" validation set.
Purely for speed/simplicity, I recommend training the model first without cross-validation (i.e. use a single training-testing partition). Once you achieve good performance on the whole dataset, you can retrain with k-fold to ensure the above, but this should make your debug cycles a bit faster.
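One way to get both the single partition and the later k-fold run from the same code is sketched below, in plain Python. `patient_folds` is a hypothetical helper, and it assumes you have a list `patient_ids` giving the patient each sample belongs to (any augmented copy would carry the same patient ID as its source raster).

```python
import random
from collections import defaultdict


def patient_folds(patient_ids, n_folds=6, seed=0):
    """Yield (train_idx, val_idx) pairs in which no patient appears on both sides.

    `patient_ids[i]` is the patient that sample i belongs to.
    """
    # Group sample indices by patient.
    by_patient = defaultdict(list)
    for idx, pid in enumerate(patient_ids):
        by_patient[pid].append(idx)

    # Shuffle patients (not samples) reproducibly.
    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)

    for fold in range(n_folds):
        # Every n_folds-th patient goes to this fold's validation set.
        val_patients = set(patients[fold::n_folds])
        train_idx = [i for p in patients if p not in val_patients for i in by_patient[p]]
        val_idx = [i for p in patients if p in val_patients for i in by_patient[p]]
        yield train_idx, val_idx
```

For fast debugging, `train_idx, val_idx = next(patient_folds(patient_ids))` gives a single patient-wise partition; iterating over the generator later gives the full k-fold run.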