Tags: tensorflow, keras, conv-neural-network, placeholder, sess.run

Issue with feeding a value into a placeholder tensor for sess.run()


I want to get the value of an intermediate tensor in a convolutional neural network for a specific input. I know how to do this in Keras, and although I trained this model with Keras, I'm moving towards constructing and training models using only TensorFlow. Therefore, I want to move away from something like K.function(input_layer, output_layer), which is fairly simple, and use TensorFlow directly instead. I believe I should use a placeholder, as in the following approach:

with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    loaded_model = tf.keras.models.load_model(filepath)
    graph = tf.compat.v1.get_default_graph()  
    images = tf.compat.v1.placeholder(tf.float32, shape=(None, 28, 28, 1)) # Input placeholder for MNIST images
    output_tensor = graph.get_tensor_by_name(tensor_name) # tensor_name is 'dense_1/MatMul:0'
    output = sess.run([output_tensor], feed_dict={images: x_test[0:1]}) # x_test[0:1] is of shape (1, 28, 28, 1)
    print(output)

However, the sess.run() line fails with the following error: Invalid argument: You must feed a value for placeholder tensor 'conv2d_2_input' with dtype float and shape [?,28,28,1]. I am unsure why, because the image I pass in feed_dict is of type float and has what I believe is the correct shape. Any help would be appreciated.


Solution

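  • You must feed the Keras model's own input tensor instead of creating a new placeholder. A new placeholder is disconnected from the rest of the model: the tensor you evaluate ('dense_1/MatMul:0') depends on the model's input placeholder ('conv2d_2_input'), not on the one you created, so feeding your placeholder leaves 'conv2d_2_input' unfed and sess.run() raises the error above. Take the input tensor from the loaded model instead: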

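    import tensorflow as tf

    # filepath (saved Keras model) and x_test (images of shape (N, 28, 28, 1)) are assumed to be defined
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
        # Load model into the new default graph and session
        loaded_model = tf.keras.models.load_model(filepath)
        # Take the model's own input placeholder (the 'conv2d_2_input' named in the error)
        images = loaded_model.input
        # Take output of the second layer (index 1)
        output_tensor = loaded_model.layers[1].output
        # Alternatively, look the tensor up by name as in the question:
        # output_tensor = sess.graph.get_tensor_by_name('dense_1/MatMul:0')
        # Evaluate, feeding the placeholder the output actually depends on
        output = sess.run(output_tensor, feed_dict={images: x_test[0:1]})
        print(output)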
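The same behaviour can be reproduced without a saved model. The following is a minimal sketch, assuming TensorFlow 2.x with the tf.compat.v1 API; it replaces the Keras model with a hypothetical two-op graph (the names model_input, output and stray and the weights are made up for illustration). Feeding the placeholder that the output depends on works, while feeding only a newly created, unconnected placeholder raises the same kind of "You must feed a value for placeholder tensor" error as in the question.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the loaded model: a placeholder input multiplied
# by fixed weights. Names and values are illustrative assumptions.
graph = tf.Graph()
with graph.as_default():
    model_input = tf.compat.v1.placeholder(
        tf.float32, shape=(None, 2), name="model_input")
    weights = tf.constant([[1.0], [2.0]])
    output = tf.matmul(model_input, weights, name="output")
    # A second placeholder, like `images` in the question: it is in the same
    # graph, but no op that `output` depends on consumes it.
    stray = tf.compat.v1.placeholder(
        tf.float32, shape=(None, 2), name="stray")

x = np.array([[3.0, 4.0]], dtype=np.float32)

with tf.compat.v1.Session(graph=graph) as sess:
    # Look tensors up by name, as the question does with 'dense_1/MatMul:0'.
    out_t = sess.graph.get_tensor_by_name("output:0")
    in_t = sess.graph.get_tensor_by_name("model_input:0")

    # Feeding the placeholder that `output` depends on works: 3*1 + 4*2 = 11.
    result = sess.run(out_t, feed_dict={in_t: x})
    print(result)  # [[11.]]

    # Feeding only the stray placeholder leaves model_input unfed,
    # so TensorFlow raises InvalidArgumentError.
    try:
        sess.run(out_t, feed_dict={stray: x})
        stray_failed = False
    except tf.errors.InvalidArgumentError as err:
        stray_failed = True
        print(type(err).__name__)
```

The same lookup by name works on a graph loaded from a Keras model, as long as the tensor you feed is the model's own input placeholder.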