Tags: python, tensorflow, pytorch, conv-neural-network

CNN taking input images with different number of channels


I am training a CNN to segment input images. I have a set of images with 3 channels (RGB) and some with 4 channels (RGB and infrared). Ideally, I'd like to train a single network on both, using all available channels.

Is there a strategy that accepts both kinds of image as input, takes advantage of the additional channel when it's there, and otherwise ignores it?


Solution

  • A convolutional layer fixes its number of input channels when it is created. The batch size can change between calls, and in a fully convolutional network so can the height and width, but the channel dimension of the input tensor ([batch, channels, height, width] in PyTorch) has to match what the first layer expects. Therefore, to solve your problem, I propose two methods:

    Normalize all images to 4 channels: Keep the infrared layer as the fourth channel where it exists, and where an image has no infrared layer, add an all-zero (black) channel with the same height and width as the fourth channel. Every image then has the same four-channel shape, and this should not stop the network from training.

    Create two branches in your model: At the input of your model, create two branches, one taking 3-channel images and another taking 4-channel images. Perform operations on the branches, and in the middle of your network, bring the output of each branch to a common shape so that you can merge them and continue convolution or upsampling operations with a known shape. (Look at the UNet architecture in its decoding phase to see how it merges blocks of tensors.)

    Here is a more detailed explanation of each method:

    Method 1: Normalize all images to 4 channels

    This method is the simplest and most straightforward. It adds a zero-filled fourth channel to every RGB-only image and leaves the RGB-plus-infrared images unchanged. All images then have the same shape, which is required both to stack them into one batch and to feed them to a first layer that expects four channels.

    The advantage of this method is that it is easy to implement and needs no change to your model beyond a four-channel first layer. The drawback is that the zero-filled channel carries no information for RGB-only images, and the network cannot tell a missing infrared channel from an infrared reading that is genuinely zero.
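
As a rough illustration of Method 1, here is a minimal PyTorch sketch. The helper name `pad_to_four_channels`, the 64x64 image size, and the 16 output channels are assumptions made up for this example, not part of the question. It assumes each image is a `[channels, height, width]` tensor.

```python
import torch
from torch import nn


def pad_to_four_channels(image: torch.Tensor) -> torch.Tensor:
    """Return a [4, H, W] tensor; RGB-only images get an all-zero fourth channel.

    Hypothetical helper for illustration only.
    """
    channels = image.shape[0]
    if channels == 4:
        return image  # already RGB + infrared
    if channels != 3:
        raise ValueError(f"expected 3 or 4 channels, got {channels}")
    # Build a zero channel with the same height, width, dtype and device.
    zero_channel = image.new_zeros((1, *image.shape[1:]))
    return torch.cat([image, zero_channel], dim=0)


# Dummy data standing in for one RGB image and one RGB + infrared image.
rgb = torch.rand(3, 64, 64)
rgb_ir = torch.rand(4, 64, 64)

# After padding, both images stack into a single batch.
batch = torch.stack([pad_to_four_channels(rgb), pad_to_four_channels(rgb_ir)])

# The first layer of the network is built for four input channels.
first_layer = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=3, padding=1)
features = first_layer(batch)
```

In a real pipeline the padding would more likely live in the dataset or collate step, so that the model itself only ever sees four-channel batches.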

    Method 2: Create two branches in your model

    This method is more complex, but it lets you process RGB-only and RGB-plus-infrared images separately in the first layers. This may be beneficial if the two sets of images have different properties, such as different resolutions or contrast levels.

    To implement this method, you will need to create two branches in your model. One branch takes 3-channel RGB images as input, and the other takes 4-channel RGB-plus-infrared images. You can perform the same operations on each branch or customize them for each type of image, as long as both branches output feature maps of the same shape so that the rest of the network can be shared.

    The advantage of this method is that it allows you to tailor your model to the specific properties of your images. However, it is more complex to implement and may require more training data.
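
A simplified sketch of Method 2 in PyTorch might look like the following. The class name `TwoStemSegmenter`, the layer widths, and the two-class output are all assumptions chosen for illustration. This variant routes each batch through one of two input stems and then through a shared trunk, rather than merging both branches for every input, and it assumes every batch contains images of a single kind.

```python
import torch
from torch import nn


class TwoStemSegmenter(nn.Module):
    """Toy segmentation net with one input stem per channel count (hypothetical)."""

    def __init__(self, width: int = 16, num_classes: int = 2):
        super().__init__()
        # Branch for 3-channel RGB images.
        self.rgb_stem = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=3, padding=1), nn.ReLU()
        )
        # Branch for 4-channel RGB + infrared images.
        self.rgb_ir_stem = nn.Sequential(
            nn.Conv2d(4, width, kernel_size=3, padding=1), nn.ReLU()
        )
        # Shared layers: both stems emit `width` channels, so the shape is known here.
        self.trunk = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, num_classes, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels = x.shape[1]  # input layout is [batch, channels, height, width]
        if channels == 3:
            features = self.rgb_stem(x)
        elif channels == 4:
            features = self.rgb_ir_stem(x)
        else:
            raise ValueError(f"expected 3 or 4 channels, got {channels}")
        return self.trunk(features)


model = TwoStemSegmenter()
rgb_logits = model(torch.rand(2, 3, 32, 32))     # batch of RGB images
rgb_ir_logits = model(torch.rand(2, 4, 32, 32))  # batch of RGB + infrared images
```

Because the trunk is shared, its weights are trained on both kinds of image, while each stem only sees gradients from batches of its own kind.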

    Which method is best for you will depend on your specific needs and requirements.