
How to send input data to my multi-channel deep learning model?


I have several inputs and want to implement a multi-channel deep learning model. The input will be a list of 5 arrays, each of shape (1300, 320, 320), where 1300 is the number of images in the array and each image is 320x320. The model I have in mind is similar to the image below:

[Image: diagram of the intended model architecture]

I have written the code below, but I am not sure whether this is the correct way to do it. I also do not know how to feed the inputs to the model or how to write the training part of the code.

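# Imports assumed here (tf.keras); the original post did not show them.
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization,
                                     MaxPooling2D, Flatten, Dense, concatenate)
from tensorflow.keras.models import Model

def MyMultiChannelNet(input_channels):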
    channels = []
    models = []

    for i in range(input_channels):
        input_layer = Input(shape=(320, 320, 1))
        channels.append(input_layer)

        x = Conv2D(64, (3, 3), activation='relu', padding='same')(input_layer)
        x = BatchNormalization()(x)
        x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)

        x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)

        x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)

        models.append(Model(inputs=input_layer, outputs=x))

    merged = concatenate([model.output for model in models])

    x = Flatten()(merged)
    x = Dense(4096, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dense(4096, activation='relu')(x)
    x = BatchNormalization()(x)

    output_layer = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=channels, outputs=output_layer)

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

Solution

  • What you have implemented creates 5 independent networks: each channel is processed separately, without seeing the others. Only at the end does a shallow (2-layer) MLP mix the data from all channels. This is a very non-standard model; typically we would apply a single network to the whole input. Convolutions can already process data shaped Height x Width x Channels, so none of this should be necessary. Instead:

    
    input_layer = Input(shape=(320, 320, input_channels))
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(input_layer)
    x = BatchNormalization()(x)
    ...
    

    In other words, stack your data so that it has shape [Batch, Height, Width, Channels] and apply a regular convnet stack to it.
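As a rough sketch of that stacking step, assuming the five inputs are NumPy arrays of identical shape: the name `arrays`, the reduced sizes, and the random data below are placeholders for illustration, not taken from the question.

```python
import numpy as np

# Placeholder data: 5 arrays shaped (num_images, height, width).
# The question uses (1300, 320, 320); smaller sizes keep the sketch light.
num_images, height, width, num_channels = 8, 32, 32, 5
arrays = [np.random.rand(num_images, height, width).astype("float32")
          for _ in range(num_channels)]

# Stack along a new last axis so every image carries all 5 channels.
stacked = np.stack(arrays, axis=-1)
print(stacked.shape)  # (8, 32, 32, 5)
```

A model built with `Input(shape=(height, width, num_channels))` should then accept this array directly, for example `model.fit(stacked, labels, batch_size=..., epochs=...)`, where `labels` is a hypothetical array with one binary label per image.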

    If you have a good modeling reason to keep this separation, you can still simplify your code; you don't need all the small models:

    def MyMultiChannelNet(input_channels):
        outputs = []
    
        input_layer = Input(shape=(320, 320, input_channels))
    
        for i in range(input_channels):
            # Slice out channel i, keeping the channel axis: (batch, 320, 320, 1)
            input_channel = input_layer[:, :, :, i:i+1]
            x = Conv2D(64, (3, 3), activation='relu', padding='same')(input_channel)
            x = BatchNormalization()(x)
            x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
            x = BatchNormalization()(x)
            x = MaxPooling2D((2, 2), strides=(2, 2))(x)
    
            x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
            x = BatchNormalization()(x)
            x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
            x = BatchNormalization()(x)
            x = MaxPooling2D((2, 2), strides=(2, 2))(x)
    
            x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
            x = BatchNormalization()(x)
            x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
            x = BatchNormalization()(x)
            x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
            x = BatchNormalization()(x)
            x = MaxPooling2D((2, 2), strides=(2, 2))(x)
    
            outputs.append(x)
    
        merged = concatenate(outputs)
    
        x = Flatten()(merged)
        x = Dense(4096, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dense(4096, activation='relu')(x)
        x = BatchNormalization()(x)
    
        output_layer = Dense(1, activation='sigmoid')(x)
    
        model = Model(inputs=input_layer, outputs=output_layer)
    
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
        return model
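When a single stacked input is split per channel, a range slice such as `i:i+1` on the last axis keeps the channel dimension that `Conv2D` expects, whereas a plain index `i` drops it. The NumPy sketch below illustrates the difference on placeholder data (the array name and sizes are assumptions), and also shows the list-of-arrays format that a model with one `Input` per channel would take.

```python
import numpy as np

# Placeholder stacked input: (num_images, height, width, channels).
# The question's data would be (1300, 320, 320, 5).
stacked = np.arange(4 * 6 * 6 * 5, dtype="float32").reshape(4, 6, 6, 5)

i = 2
kept = stacked[:, :, :, i:i + 1]   # range slice keeps the channel axis
dropped = stacked[:, :, :, i]      # plain index removes it

print(kept.shape)     # (4, 6, 6, 1)
print(dropped.shape)  # (4, 6, 6)

# For a model with one Input per channel, the same data can be
# supplied as a list of single-channel arrays instead.
multi_input = [stacked[..., c:c + 1] for c in range(stacked.shape[-1])]
print(len(multi_input), multi_input[0].shape)  # 5 (4, 6, 6, 1)
```

For training, the multi-input version from the question should take the list (`model.fit(multi_input, labels, ...)`), while the single-input versions take the stacked array; `labels` is again a hypothetical array of binary targets, one per image.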