Tags: python, tensorflow, keras, deep-learning, feature-extraction

How to get output of intermediate Keras layers in batches?


I am not sure how to get the output of an intermediate layer in Keras. I have read other questions on Stack Overflow, but their solutions seem to take a single sample as input. I want to get the output features (at an intermediate layer) in batches as well. Here is my model:

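from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

model = Sequential()
model.add(ResNet50(include_top = False, pooling = RESNET50_POOLING_AVERAGE, weights = resnet_weights_path)) #None
model.add(Dense(784, activation = 'relu'))
model.add(Dense(NUM_CLASSES, activation = DENSE_LAYER_ACTIVATION))
model.layers[0].trainable = True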

After training the model, I want to get the output of the first Dense layer (784-dimensional). Is this the right way to do it?

pred = model.layers[1].predict_generator(data_generator, steps = len(data_generator), verbose = 1)

I am new to Keras so I am a little unsure. Do I need to compile the model again after training?


Solution

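  • No, you don't need to compile again after training. compile() only configures training (optimizer, loss and metrics); prediction uses the trained weights as they are.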

    Based on your Sequential model, the layers are indexed as follows:

    Layer 0 :: model.add(ResNet50(include_top = False, pooling = RESNET50_POOLING_AVERAGE, weights = resnet_weights_path)) #None
    Layer 1 :: model.add(Dense(784, activation = 'relu'))
    Layer 2 :: model.add(Dense(NUM_CLASSES, activation = DENSE_LAYER_ACTIVATION))
    

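    Layer indices may differ if you build the model with the Functional API. There, model.layers[0] is typically the InputLayer, so the first Dense layer would not be at index 1.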
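To make the indexing difference concrete, here is a small sketch with a hypothetical functional model (the layer names `dense_16` and `head` are made up for illustration). Looking the layer up by name with `get_layer()` avoids relying on indices at all; the printed class list shows where the `InputLayer` lands in your Keras version.

```python
import numpy as np
from tensorflow import keras

# Hypothetical functional model; the layer names are chosen for this sketch.
inputs = keras.Input(shape=(32,))
hidden = keras.layers.Dense(16, activation="relu", name="dense_16")(inputs)
outputs = keras.layers.Dense(3, activation="softmax", name="head")(hidden)
functional_model = keras.Model(inputs, outputs)

# Inspect the indexing: in tf.keras 2.x the InputLayer is listed first here.
print([type(layer).__name__ for layer in functional_model.layers])

# Fetching the layer by name sidesteps index differences between the two APIs.
feature_layer = functional_model.get_layer("dense_16")
extractor = keras.Model(inputs=functional_model.input, outputs=feature_layer.output)

x = np.random.default_rng(0).random((10, 32), dtype=np.float32)
features = extractor.predict(x, batch_size=4, verbose=0)  # processed in 3 batches
print(features.shape)  # (10, 16)
```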

    With TensorFlow 2.1.0, you can build a new Model that reuses the trained layers and ends at the intermediate layer you want, then call predict() on it over all batches:

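    from tensorflow.keras.models import Model

    # New model that shares the trained layers and stops at the 784-unit Dense layer
    model_dense_784 = Model(inputs=model.input, outputs=model.layers[1].output)

    # predict() accepts generators directly (predict_generator is deprecated);
    # steps=len(train_data_gen) covers every batch, while steps=1 returns only the first one
    pred_dense_784 = model_dense_784.predict(train_data_gen, steps=len(train_data_gen))
    print(pred_dense_784.shape)  # expected: (number_of_samples, 784)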

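The snippet above assumes an existing trained `model` and `train_data_gen`. Below is a self-contained sketch of the same approach that runs end to end. Assumptions: a tiny hypothetical functional model stands in for `ResNet50(include_top=False, pooling='avg')` so nothing has to be downloaded, `ImageBatches` is a hypothetical `keras.utils.Sequence` playing the role of `data_generator`, and the data is random.

```python
import numpy as np
from tensorflow import keras

NUM_CLASSES = 3  # hypothetical value, for the sketch only

# Hypothetical stand-in for ResNet50(include_top=False, pooling='avg'): a tiny
# functional model mapping 32x32 RGB images to a pooled feature vector.
backbone_in = keras.Input(shape=(32, 32, 3))
backbone_x = keras.layers.Conv2D(8, 3, activation="relu")(backbone_in)
backbone_out = keras.layers.GlobalAveragePooling2D()(backbone_x)
backbone = keras.Model(backbone_in, backbone_out, name="stand_in_backbone")

# Same layout as the question: backbone (0), Dense(784) (1), classifier head (2).
model = keras.Sequential([
    keras.Input(shape=(32, 32, 3)),
    backbone,
    keras.layers.Dense(784, activation="relu"),
    keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])

# Train briefly on random data so the model has "trained" weights.
rng = np.random.default_rng(0)
x_train = rng.random((20, 32, 32, 3), dtype=np.float32)
y_train = rng.integers(0, NUM_CLASSES, size=20)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(x_train, y_train, batch_size=8, epochs=1, verbose=0)


class ImageBatches(keras.utils.Sequence):
    """Hypothetical batch generator that yields image batches only (no labels)."""

    def __init__(self, images, batch_size):
        super().__init__()
        self.images = images
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, including a smaller final batch.
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        return self.images[start:start + self.batch_size]


# Sub-model sharing the trained layers; it is never compiled, since predict() does not need it.
# model.inputs (a list) is used here; model.input also works on tf.keras 2.x.
feature_extractor = keras.Model(inputs=model.inputs, outputs=model.layers[1].output)

batches = ImageBatches(x_train, batch_size=8)  # 3 batches: 8 + 8 + 4 images
features = feature_extractor.predict(batches, steps=len(batches), verbose=0)
print(features.shape)  # (20, 784)
```

Because the extractor reuses the same layer objects as `model`, it sees the trained weights directly; there is nothing to copy and nothing to recompile.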
    Prefer model.predict() over model.predict_generator(): the latter is deprecated as of TensorFlow 2.1, and predict() accepts generators directly.
    You can also check the .shape attribute of the result to confirm that it matches the output shape reported for that layer in model.summary().
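If you want more control over batching (for example, to write each batch of features to disk as it is produced), one option is to call `predict_on_batch()` in a loop and concatenate the results, then compare the per-sample shape against the layer's symbolic output shape, which is what `model.summary()` reports. This is a sketch: the model, the layer name `dense_784`, and the sizes are hypothetical.

```python
import numpy as np
from tensorflow import keras

# Hypothetical trained model; the 784-unit layer is named "dense_784"
# so it can be fetched with get_layer() instead of by index.
model = keras.Sequential([
    keras.Input(shape=(64,)),
    keras.layers.Dense(784, activation="relu", name="dense_784"),
    keras.layers.Dense(5, activation="softmax", name="head"),
])
feature_extractor = keras.Model(
    inputs=model.inputs, outputs=model.get_layer("dense_784").output
)

x = np.random.default_rng(0).random((10, 64), dtype=np.float32)
batch_size = 4

# Run the extractor one batch at a time and stitch the results together.
chunks = [
    feature_extractor.predict_on_batch(x[start:start + batch_size])
    for start in range(0, len(x), batch_size)
]
features = np.concatenate(chunks, axis=0)

# The symbolic output shape is (None, 784), as in model.summary();
# None is the batch dimension, so compare only the remaining dimensions.
per_sample_shape = tuple(feature_extractor.output.shape[1:])
print(features.shape, per_sample_shape)  # (10, 784) (784,)

# The manual loop should agree with a single predict() call.
reference = feature_extractor.predict(x, batch_size=batch_size, verbose=0)
print(np.allclose(features, reference, atol=1e-5))  # True
```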