I am working with medical images: I have 130 patient volumes, and each volume consists of N DICOM images (slices).
The problem is that the number of slices N varies between volumes.
Half of the volumes (50%) have 20 slices; the rest differ from that by 3 or 4 slices, and some by more than 10 slices (so much that interpolating to make the number of slices equal across volumes is not possible).
I am able to use Conv3d for volumes where the depth N (number of slices) is the same across volumes, but I have to use the entire dataset for the classification task. How do I incorporate the entire dataset and feed it to my network model?
If I understand your question, you have 130 3-dimensional images that you need to feed into a 3D ConvNet. I'll assume that, if N were the same for all of your data, your batches would be tensors of shape (batch_size, channels, N, H, W), and that your problem is that N varies between data samples.
So there are two problems. First, your model needs to handle data with different values of N. Second, there's the more implementation-related problem of batching data of different lengths.
Both problems come up in video classification models. For the first, I don't think there's a way around interpolating SOMEWHERE in your model (unless you're willing to pad/cut/sample): for any kind of classification task, you pretty much need a constant-sized layer at your classification head. However, the interpolation doesn't have to happen right at the beginning. For example, if for an input tensor of size (batch, 3, 20, 256, 256) your network conv-pools down to (batch, 1024, 4, 1, 1), then you can perform an adaptive pool (e.g. https://pytorch.org/docs/stable/nn.html#torch.nn.AdaptiveAvgPool3d) right before the output, so that anything larger is downsampled to that size before prediction.
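To make the adaptive-pool idea concrete, here is a minimal sketch. The class name TinyVolumeClassifier, the layer sizes, and the choice of a single grayscale input channel are all assumptions for illustration, not a recommended architecture. It pools to a (1, 1, 1) grid rather than keeping a depth of 4 as in the example above; either works as long as the output size is fixed.

```python
import torch
import torch.nn as nn


class TinyVolumeClassifier(nn.Module):
    """Hypothetical 3D ConvNet whose head does not depend on the depth N."""

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(in_channels, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.Conv3d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Collapse whatever (D', H', W') remains to a fixed 1 x 1 x 1 grid,
        # so the linear layer always sees 16 features.
        self.pool = nn.AdaptiveAvgPool3d(output_size=1)
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.features(x)         # (B, 16, D', H', W')
        x = self.pool(x).flatten(1)  # (B, 16)
        return self.classifier(x)


model = TinyVolumeClassifier()
for depth in (16, 20, 31):
    volume = torch.randn(2, 1, depth, 64, 64)  # (B, C, N, H, W)
    print(depth, tuple(model(volume).shape))
```

Every depth produces logits of the same shape, so the classification head never sees N. Note that all volumes within one batch still need the same N, which is the batching problem discussed next.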
The other option is padding and/or truncating and/or resampling the images so that all of your data is the same length. For videos, sometimes people pad by looping the frames, or you could pad with zeros. What's valid depends on whether your length axis represents time, or something else.
For the second problem, batching: if you're familiar with PyTorch's DataLoader/Dataset pipeline, you'll need to write a custom collate_fn that takes a list of outputs from your dataset object and stacks them together into a batch tensor. In this function, you can decide whether to pad, truncate, or do something else, so that you end up with a tensor of the correct shape. Different batches can then have different values of N. A simple example of implementing this pipeline is here: https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/data_loader.py
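A rough sketch of such a collate_fn, assuming each dataset item is a (volume, label) pair with the volume shaped (channels, N, H, W). ToyVolumes and pad_collate are made-up names, and zero-padding at the end of the depth axis is just one of the options mentioned above.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class ToyVolumes(Dataset):
    """Stand-in dataset: random volumes with varying depth (hypothetical)."""

    def __init__(self, depths, height=32, width=32):
        self.depths = depths
        self.height = height
        self.width = width

    def __len__(self):
        return len(self.depths)

    def __getitem__(self, idx):
        volume = torch.randn(1, self.depths[idx], self.height, self.width)
        label = idx % 2  # dummy label
        return volume, label


def pad_collate(batch):
    """Zero-pad every volume along depth to the deepest volume in the batch."""
    volumes, labels = zip(*batch)
    max_depth = max(v.shape[1] for v in volumes)
    padded = []
    for v in volumes:
        extra = max_depth - v.shape[1]
        # F.pad works from the last dim backwards:
        # (W_left, W_right, H_left, H_right, D_left, D_right).
        padded.append(F.pad(v, (0, 0, 0, 0, 0, extra)))
    # Keep the original depths in case the model wants to mask the padding.
    depths = torch.tensor([v.shape[1] for v in volumes])
    return torch.stack(padded), torch.tensor(labels), depths


loader = DataLoader(ToyVolumes([20, 17, 24, 20]), batch_size=2,
                    collate_fn=pad_collate)
for volumes, labels, depths in loader:
    print(tuple(volumes.shape), labels.tolist(), depths.tolist())
```

Each batch is padded only to its own maximum depth, so different batches come out with different N. Returning the true depths is optional; it is there in case you later want to ignore the padded slices.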
Something else that might help with batching is putting your data into buckets depending on their N dimension. That way, you might be able to avoid lots of unnecessary padding.
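One possible way to do the bucketing, sketched in plain Python. The function name bucketed_batches and the exact-match grouping are assumptions; with many rare depths you would probably bucket by depth range instead and still pad within each bucket.

```python
import random
from collections import defaultdict


def bucketed_batches(depths, batch_size, shuffle=True, seed=0):
    """Group sample indices by depth, then cut each group into batches."""
    rng = random.Random(seed)
    buckets = defaultdict(list)
    for idx, depth in enumerate(depths):
        buckets[depth].append(idx)

    batches = []
    for indices in buckets.values():
        if shuffle:
            rng.shuffle(indices)  # shuffle within a bucket
        for start in range(0, len(indices), batch_size):
            batches.append(indices[start:start + batch_size])
    if shuffle:
        rng.shuffle(batches)  # shuffle the order of the batches
    return batches


# depths[i] is the number of slices N of sample i (toy values).
depths = [20, 17, 20, 24, 20, 17, 20, 24]
for batch in bucketed_batches(depths, batch_size=2):
    print(batch, [depths[i] for i in batch])
```

Every batch then contains volumes of a single depth, so no padding is needed. The resulting list of index lists can be passed to DataLoader through its batch_sampler argument (instead of batch_size and shuffle); regenerate it with a different seed each epoch if you want the order to change.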