Keras Grad-CAM layer definition


I am training a Keras model that uses VGG16 as its base model, as shown below.

Model summary:

[image: model summary]

I am also generating a Grad-CAM heatmap, using one of the layers visible in the summary as the target layer because I cannot choose one "inside" the VGG16 layer. As a result, I am getting a very rough, low-resolution heatmap.

Heatmap:

[image: heatmap]

Is there a way to reach the layers inside the VGG16 layer to get a better heatmap? That way I could choose the last convolutional layer of the VGG16 model.

Thanks in advance.

import numpy as np
import tensorflow as tf

def get_img_array(img_path, size):
    # `img` is a PIL image resized to `size`, e.g. (299, 299)
    img = tf.keras.utils.load_img(img_path, target_size=size)
    # `array` is a float32 NumPy array of shape (height, width, 3)
    array = tf.keras.utils.img_to_array(img)
    # We add a dimension to transform our array into a "batch"
    # of shape (1, height, width, 3)
    array = np.expand_dims(array, axis=0)
    return array

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    # First, we create a model that maps the input image to the activations
    # of the last conv layer as well as the output predictions
    grad_model = tf.keras.models.Model(
        model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
    )

    # Then, we compute the gradient of the top predicted class for our input image
    # with respect to the activations of the last conv layer
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]
    
    # This is the gradient of the output neuron (top predicted or chosen)
    # with regard to the output feature map of the last conv layer
    grads = tape.gradient(class_channel, last_conv_layer_output)
    
    # This is a vector where each entry is the mean intensity of the gradient
    # over a specific feature map channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    
    # We multiply each channel in the feature map array
    # by "how important this channel is" with regard to the top predicted class
    # then sum all the channels to obtain the heatmap class activation
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    
    # For visualization purpose, we will also normalize the heatmap between 0 & 1
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()
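
To make the weighting step in make_gradcam_heatmap easier to follow, here is a minimal NumPy-only sketch of the same arithmetic on a tiny hand-made feature map. The function name gradcam_from_arrays and the toy arrays are made up for illustration. It works on a single image without a batch axis, and it guards against dividing by zero, which the function above does not do.

```python
import numpy as np

def gradcam_from_arrays(feature_map, grads):
    """Toy Grad-CAM step. Both inputs have shape (height, width, channels)."""
    # One weight per channel: the mean gradient over all spatial positions
    weights = grads.mean(axis=(0, 1))          # shape (channels,)
    # Weighted sum of the channels at every spatial position
    heatmap = feature_map @ weights            # shape (height, width)
    # Keep only positive evidence for the class, then scale to [0, 1]
    heatmap = np.maximum(heatmap, 0.0)
    peak = heatmap.max()
    if peak > 0:                               # assumption: skip scaling for an all-zero map
        heatmap = heatmap / peak
    return heatmap

# Hypothetical 2x2 feature map with two channels
feature_map = np.zeros((2, 2, 2))
feature_map[0, 0, 0] = 4.0    # channel 0 is active in the top-left corner
feature_map[1, 1, 1] = 2.0    # channel 1 is active in the bottom-right corner

grads = np.zeros((2, 2, 2))
grads[..., 0] = 1.0           # channel 0 supports the class
grads[..., 1] = -1.0          # channel 1 speaks against the class

heatmap = gradcam_from_arrays(feature_map, grads)
print(heatmap)
```

Only the top-left cell stays lit, because channel 1 gets a negative weight and its contribution is clipped to zero. Note that the heatmap has exactly the spatial size of the feature map, so the resolution of the result depends entirely on which layer is chosen.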

I tried building a custom model in which I define all the layers myself. That worked, but as soon as I use a base model like VGG16, it stops working.


Solution

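  • You can use the .call() method of the VGG16 model on your input tensor instead of calling the model itself. The layers of VGG16 then become individual layers of your model rather than one nested layer: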

    inp = tf.keras.Input((299,299,3))
    x = tf.keras.applications.vgg16.VGG16(include_top=False, input_shape=(299,299,3)).call(inp)
    x = tf.keras.layers.Conv2D(32,3)(x)
    x = tf.keras.layers.Dense(32)(x)
    x = tf.keras.layers.Dense(2)(x)
    model = tf.keras.models.Model(inp, x)
    model.summary()
    

    Your model summary then looks like this, and you can access any layer of VGG16 by name, for example model.get_layer("block5_conv3"), the last convolutional layer:

    Model: "model"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     input_1 (InputLayer)        [(None, 299, 299, 3)]     0         
                                                                     
     block1_conv1 (Conv2D)       (None, 299, 299, 64)      1792      
                                                                     
     block1_conv2 (Conv2D)       (None, 299, 299, 64)      36928     
                                                                     
     block1_pool (MaxPooling2D)  (None, 149, 149, 64)      0         
                                                                     
     block2_conv1 (Conv2D)       (None, 149, 149, 128)     73856     
                                                                     
     block2_conv2 (Conv2D)       (None, 149, 149, 128)     147584    
                                                                     
     block2_pool (MaxPooling2D)  (None, 74, 74, 128)       0         
                                                                     
     block3_conv1 (Conv2D)       (None, 74, 74, 256)       295168    
                                                                     
     block3_conv2 (Conv2D)       (None, 74, 74, 256)       590080    
                                                                     
     block3_conv3 (Conv2D)       (None, 74, 74, 256)       590080    
                                                                     
     block3_pool (MaxPooling2D)  (None, 37, 37, 256)       0         
                                                                     
     block4_conv1 (Conv2D)       (None, 37, 37, 512)       1180160   
                                                                     
     block4_conv2 (Conv2D)       (None, 37, 37, 512)       2359808   
                                                                     
     block4_conv3 (Conv2D)       (None, 37, 37, 512)       2359808   
                                                                     
     block4_pool (MaxPooling2D)  (None, 18, 18, 512)       0         
                                                                     
     block5_conv1 (Conv2D)       (None, 18, 18, 512)       2359808   
                                                                     
     block5_conv2 (Conv2D)       (None, 18, 18, 512)       2359808   
                                                                     
     block5_conv3 (Conv2D)       (None, 18, 18, 512)       2359808   
                                                                     
     block5_pool (MaxPooling2D)  (None, 9, 9, 512)         0         
                                                                     
     conv2d (Conv2D)             (None, 7, 7, 32)          147488    
                                                                     
     dense (Dense)               (None, 7, 7, 32)          1056      
                                                                     
     dense_1 (Dense)             (None, 7, 7, 2)           66        
                                                                     
    =================================================================
    Total params: 14863298 (56.70 MB)
    Trainable params: 14863298 (56.70 MB)
    Non-trainable params: 0 (0.00 Byte)
    _________________________________________________________________
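
The output shapes in this summary also explain the resolution difference. A Grad-CAM heatmap has the same height and width as the output of the layer it is computed from. The sketch below reproduces the spatial sizes from the summary with plain arithmetic, assuming what the shapes suggest: the VGG16 convolutions keep the size, each pooling layer uses a 2x2 window with stride 2, and the extra Conv2D(32, 3) uses no padding. The helper names are made up for illustration.

```python
def pooled_size(size, pool=2, stride=2):
    # Max pooling without padding: floor((size - pool) / stride) + 1
    return (size - pool) // stride + 1

def valid_conv_size(size, kernel=3, stride=1):
    # Convolution without padding shrinks the map by kernel - 1
    return (size - kernel) // stride + 1

size = 299
sizes = {}
for block in range(1, 6):
    # The convolutions inside a VGG16 block keep the spatial size
    sizes[f"block{block}_conv"] = size
    size = pooled_size(size)
    sizes[f"block{block}_pool"] = size

# The extra Conv2D(32, 3) on top of block5_pool
sizes["conv2d"] = valid_conv_size(size)

for name, value in sizes.items():
    print(f"{name}: {value}x{value}")
```

With a 299x299 input, a heatmap taken from block5_conv3 is 18x18, while one taken from the conv2d layer after the base model is only 7x7, so choosing block5_conv3 gives a noticeably finer map.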