
Integrating numerical/physical data for CNN image classification


I am attempting to use a CNN to classify medical images in Python using Keras. These images come with non-image metadata, such as age and gender, that can influence the model's decision. How can I train a CNN on both the images and this metadata so that its classifications are based on both?


Solution

  • There are a couple of possibilities that I can think of off the top of my head, but the simplest is to extract features from the medical images with a CNN, flatten the CNN output, and concatenate the non-image data to it. Here is a sketch, assuming 512x512 RGB images and 10 classes. It uses the Keras functional API, which allows a model to have multiple inputs.

    import tensorflow as tf
    import numpy as np
    
    num_classes = 10
    
    H, W = 512, 512
    # Define inputs with their per-sample shapes (the batch dimension is implicit)
    imgs = tf.keras.Input(shape=(H, W, 3), dtype=tf.float32)
    genders = tf.keras.Input(shape=(1,), dtype=tf.float32)
    ages = tf.keras.Input(shape=(1,), dtype=tf.float32)
    
    # Extract image features
    features = tf.keras.layers.Conv2D(64, 4, strides = 4, activation = 'relu')(imgs)
    features = tf.keras.layers.MaxPooling2D()(features)
    features = tf.keras.layers.Conv2D(128,3, strides = 2, activation = 'relu')(features)
    features = tf.keras.layers.MaxPooling2D()(features)
    features = tf.keras.layers.Conv2D(256, 3, strides = 2, activation = 'relu')(features)
    features = tf.keras.layers.Conv2D(512, 3, strides = 2, activation = 'relu')(features)
    
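    # Flatten the feature maps into one vector per image
    flat_features = tf.keras.layers.Flatten()(features)
    
    # Concatenate gender and age onto the image features.
    # A Keras layer is used here because symbolic Keras tensors cannot be
    # passed to raw TensorFlow ops such as tf.concat in recent Keras versions.
    flat_features = tf.keras.layers.Concatenate(axis=-1)([flat_features, genders, ages])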
    
    # Downsample
    xx = tf.keras.layers.Dense(2048, activation = 'relu')(flat_features)
    xx = tf.keras.layers.Dense(1024, activation = 'relu')(xx)
    xx = tf.keras.layers.Dense(512, activation = 'relu')(xx)
    
    #Calculate probabilities for each class
    logits = tf.keras.layers.Dense(num_classes)(xx)
    probs = tf.keras.layers.Softmax()(logits)
    
    model = tf.keras.Model(inputs = [imgs, genders, ages], outputs = probs)
    
    model.summary()
    

    This architecture is not especially standard, so treat it as a starting point. You might want to make the dense part after the concatenation (the "decoder" here) deeper and/or decrease the number of parameters in the CNN encoder.
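The model above declares `genders` and `ages` as float inputs of shape `(1,)`, so each one has to be fed as a float32 array of shape `(N, 1)`. The following is a minimal sketch of that preparation step. The helper name `encode_metadata`, the `"M"`/`"F"` encoding, and the fixed age scale of 100 are all assumptions made for illustration and are not part of the original answer.

```python
import numpy as np


def encode_metadata(genders, ages, age_scale=100.0):
    """Turn raw gender labels and ages into float32 column vectors.

    Assumptions (hypothetical, for illustration only): gender is recorded
    as "M" or "F" and mapped to 0.0 / 1.0, and age is divided by a fixed
    scale so that it lands roughly in [0, 1].
    """
    gender_map = {"M": 0.0, "F": 1.0}  # hypothetical encoding
    gender_arr = np.array([gender_map[g] for g in genders], dtype=np.float32)
    age_arr = np.asarray(ages, dtype=np.float32) / age_scale
    # Inputs declared with shape (1,) expect arrays of shape (N, 1).
    return gender_arr.reshape(-1, 1), age_arr.reshape(-1, 1)


gender_col, age_col = encode_metadata(["F", "M", "F"], [34, 61, 47])
print(gender_col.shape, age_col.shape)
```

With arrays like these, training would pass the three inputs as a list in the same order as `inputs = [imgs, genders, ages]`, for example `model.fit([images, gender_col, age_col], labels)`, where `images` has shape `(N, 512, 512, 3)` and the model has first been compiled with a loss that matches the label format. Scaling the age matters because an unscaled value such as 61 would be much larger than typical activations coming out of the CNN.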