python numpy keras feature-extraction keras-layer

Display extracted feature vector from trained layer of the model as an image


I am using transfer learning to recognize objects. I used a pretrained VGG16 model as the base model and added my own classifier on top of it using Keras. I then trained the model on my data, and it works well. Now I want to see the features generated by the intermediate layers of the model for a given input. I used the following code for this purpose:

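import numpy as np
from keras.models import Model
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input

def ModeloutputAtthisLayer(model, layernme, imgnme, width, height):
    # Sub-model that maps the original input to the chosen layer's output
    intermediate_layer_model = Model(inputs=model.input,
                                     outputs=model.get_layer(layernme).output)
    # Keras expects target_size as (height, width)
    img = image.load_img(imgnme, target_size=(height, width))
    imageArray = image.img_to_array(img)
    # Add the batch axis: (height, width, channels) -> (1, height, width, channels)
    image_batch = np.expand_dims(imageArray, axis=0)
    processed_image = preprocess_input(image_batch.copy())
    intermediate_output = intermediate_layer_model.predict(processed_image)
    print("output shape of", layernme, "is", intermediate_output.shape)
    return intermediate_output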

In the code, I used np.expand_dims to add an extra batch dimension, because the network expects input of shape (batch_size, height, width, channels). This code works fine. The shape of the resulting feature output is (1, 224, 224, 64).
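A minimal NumPy-only illustration of the batch-axis step, using a zero array as a stand-in for a loaded 224 x 224 RGB image (the array contents are an assumption; only the shapes matter here):

```python
import numpy as np

# Stand-in for the output of image.img_to_array: (height, width, channels)
image_array = np.zeros((224, 224, 3), dtype=np.float32)

# Add a leading batch axis so the array matches (batch_size, height, width, channels)
image_batch = np.expand_dims(image_array, axis=0)
print(image_batch.shape)  # (1, 224, 224, 3)

# Indexing with np.newaxis produces the same shape
assert image_array[np.newaxis, ...].shape == image_batch.shape
```

np.squeeze(..., axis=0) is the inverse operation: it drops that size-1 batch axis again.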

Now I want to display this as an image. I understand that the extra batch dimension must be removed first, so I used the following lines of code:

import matplotlib.pyplot as plt

imge = np.squeeze(intermediate_output, axis=0)  # (224, 224, 64)
plt.imshow(imge)

However, it throws an error:

"Invalid dimensions for image data"
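The error is consistent with the array shapes matplotlib documents for imshow: (M, N) for scalar data, (M, N, 3) for RGB and (M, N, 4) for RGBA. The sketch below uses a hypothetical helper, imshow_can_display, that mirrors only those documented shapes (it is not part of matplotlib), and random data as a stand-in for the real layer output:

```python
import numpy as np

def imshow_can_display(arr):
    """Hypothetical check mirroring imshow's documented shapes: (M, N), (M, N, 3), (M, N, 4)."""
    return arr.ndim == 2 or (arr.ndim == 3 and arr.shape[-1] in (3, 4))

# Random stand-in for the (1, 224, 224, 64) intermediate layer output
rng = np.random.default_rng(0)
intermediate_output = rng.random((1, 224, 224, 64), dtype=np.float32)

imge = np.squeeze(intermediate_output, axis=0)
print(imge.shape)                          # (224, 224, 64)
print(imshow_can_display(imge))            # False: 64 channels is neither 3 nor 4
print(imshow_can_display(imge[:, :, 0]))   # True: a single 2-D feature map
```

Removing the batch axis is therefore necessary but not sufficient: the remaining 64-channel array still has to be split into 2-D slices (or otherwise reduced) before imshow can draw it.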

How can I display the extracted feature vector as an image? Any suggestions are welcome.


Solution

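  • Your feature shape is (1, 224, 224, 64). plt.imshow only accepts 2-D arrays of shape (M, N) or 3-D arrays of shape (M, N, 3) (RGB) or (M, N, 4) (RGBA), so you cannot plot a 64-channel image directly. What you can do instead is plot each channel independently as a grayscale image, like this: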

    import math
    import numpy as np
    import matplotlib.pyplot as plt

    imge = np.squeeze(intermediate_output, axis=0)  # (224, 224, 64)
    filters = imge.shape[2]                         # number of channels
    plt.figure(1, figsize=(32, 32))                 # figure size in inches
    n_columns = 8
    n_rows = math.ceil(filters / n_columns)         # 64 / 8 -> 8 rows
    for i in range(filters):
        plt.subplot(n_rows, n_columns, i + 1)
        plt.title('Filter ' + str(i))
        plt.imshow(imge[:, :, i], interpolation="nearest", cmap="gray")
        plt.axis('off')
    plt.show()
    

    This will plot 64 images in 8 rows and 8 columns.
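An alternative to 64 separate subplots is to tile all channels into a single 2-D "montage" array with NumPy and draw it with one imshow call. The sketch below is an assumption-based illustration, not part of the answer above: make_feature_montage is a hypothetical helper, each channel is min-max scaled to [0, 1] on its own, and random data stands in for the squeezed layer output.

```python
import math
import numpy as np

def make_feature_montage(feature_maps, n_columns=8, pad=1):
    """Hypothetical helper: tile the channels of an (H, W, C) array into one 2-D grid.

    Each channel is min-max scaled to [0, 1] independently so that
    low-contrast channels stay visible. Padding pixels are 0.
    """
    height, width, n_channels = feature_maps.shape
    n_rows = math.ceil(n_channels / n_columns)
    grid = np.zeros((n_rows * (height + pad) - pad,
                     n_columns * (width + pad) - pad), dtype=np.float32)
    for c in range(n_channels):
        fmap = feature_maps[:, :, c].astype(np.float32)
        span = fmap.max() - fmap.min()
        # Avoid division by zero for constant channels (e.g. all-zero after ReLU)
        fmap = (fmap - fmap.min()) / span if span > 0 else np.zeros_like(fmap)
        row, col = divmod(c, n_columns)
        top, left = row * (height + pad), col * (width + pad)
        grid[top:top + height, left:left + width] = fmap
    return grid

# Random stand-in for the squeezed (224, 224, 64) layer output
imge = np.random.default_rng(0).random((224, 224, 64), dtype=np.float32)
grid = make_feature_montage(imge, n_columns=8)
print(grid.shape)  # (1799, 1799)
```

The result is a plain 2-D array, so plt.imshow(grid, cmap="gray") can display it directly. Per-channel scaling makes each map easy to read but hides differences in magnitude between channels; scaling by the global minimum and maximum of imge instead would preserve those differences.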