python tensorflow keras matrix-multiplication

Keras GradCam implementation that can process batches of images instead of a single image at a time


I'm following the GradCam example from the Keras documentation https://keras.io/examples/vision/grad_cam/ and want to modify it so that it can process a batch of images instead of only a single image at a time.

I got this working, but only by calling tf.map_fn, which I would like to remove in the hope of improving performance.

My progress so far (whole code at Google Colab):

# https://keras.io/examples/vision/grad_cam/
def make_gradcam_heatmap(grad_model, images, pred_index=None):
    images = tf.cast(images, tf.float32)

    # Then, we compute the gradient of the top predicted class for our input image
    # with respect to the activations of the last conv layer
    with tf.GradientTape() as tape:
        tape.watch(images)
        last_conv_layer_output, preds = grad_model(images)
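        if pred_index is None:
            # Note: this takes the top class of the first image only
            # and uses that class index for every image in the batch
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]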

    # This is the gradient of the output neuron (top predicted or chosen)
    # with regard to the output feature map of the last conv layer
    grads = tape.gradient(class_channel, last_conv_layer_output)
    assert grads is not None, "GradientTape returned gradients=None"

    # This is a vector where each entry is the mean intensity of the gradient
    # over a specific feature map channel
    pooled_grads = tf.reduce_mean(grads, axis=(1, 2))

    # We multiply each channel in the feature map array
    # by "how important this channel is" with regard to the top predicted class
    # then sum all the channels to obtain the heatmap class activation
    def single_image(index):
        return last_conv_layer_output[index] @ pooled_grads[index][tf.newaxis, ..., tf.newaxis]

    heatmaps = tf.map_fn(single_image, tf.range(tf.shape(grads)[0]), dtype=tf.float32)

    # normalize the whole batch to [0, 1]
    heatmaps -= tf.math.reduce_min(heatmaps, axis=(0,1,2,3))
    heatmaps /= tf.math.reduce_max(heatmaps, axis=(0,1,2,3))

    return heatmaps

Is there any way to rewrite this code in such a way that it doesn't use tf.map_fn?

    def single_image(index):
        return last_conv_layer_output[index] @ pooled_grads[index][tf.newaxis, ..., tf.newaxis]

    heatmaps = tf.map_fn(single_image, tf.range(tf.shape(grads)[0]), dtype=tf.float32)

Solution

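  • You can call tf.matmul on the whole batch at once. It broadcasts over the leading batch dimensions, so with last_conv_layer_output of shape (B, H, W, C) and pooled_grads of shape (B, C), the product (B, H, W, C) @ (B, 1, C, 1) has shape (B, H, W, 1), the same shape the tf.map_fn version returns:

    heatmaps = tf.matmul(last_conv_layer_output, pooled_grads[:, tf.newaxis, :, tf.newaxis])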
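
As a quick sanity check, here is a minimal sketch that compares the per-image loop with the single broadcast matmul. It uses NumPy as a stand-in for TensorFlow, on the assumption that np.matmul and tf.matmul broadcast leading batch dimensions the same way. The function names, shapes and random inputs are made up for illustration and are not from the original code. It also includes an einsum form, which states the "weighted sum over channels" more directly; tf.einsum accepts the same subscript string.

```python
import numpy as np

def heatmaps_loop(features, weights):
    # Per-image reference, mirroring single_image in the question:
    # (H, W, C) @ (1, C, 1) -> (H, W, 1), then stacked over the batch.
    return np.stack([
        features[i] @ weights[i][np.newaxis, :, np.newaxis]
        for i in range(features.shape[0])
    ])

def heatmaps_matmul(features, weights):
    # One broadcast matmul: (B, H, W, C) @ (B, 1, C, 1) -> (B, H, W, 1)
    return np.matmul(features, weights[:, np.newaxis, :, np.newaxis])

def heatmaps_einsum(features, weights):
    # Weighted sum over the channel axis; the trailing axis is restored
    # so the result has the same shape as the matmul version.
    return np.einsum("bhwc,bc->bhw", features, weights)[..., np.newaxis]

# Hypothetical stand-ins for last_conv_layer_output and pooled_grads
rng = np.random.default_rng(0)
features = rng.standard_normal((3, 5, 7, 4))  # (B, H, W, C)
weights = rng.standard_normal((3, 4))         # (B, C)

same = np.allclose(heatmaps_loop(features, weights),
                   heatmaps_matmul(features, weights))
print(heatmaps_matmul(features, weights).shape, same)
```

If the trailing singleton axis is not needed, the einsum form without `[..., np.newaxis]` returns heatmaps of shape (B, H, W) directly.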