Tags: python, tensorflow, deep-learning, keras

Connecting Keras models / replacing input but keeping layers


This question is similar to Keras replacing input layer.

I have a classifier network and an autoencoder network, and I want to use the output of the autoencoder (i.e. encoding + decoding, as a preprocessing step) as the input to the classifier, but only after the classifier has already been trained on the regular data.

The classification network was built with the functional API like this (based on this example):

clf_input = Input(shape=(28,28,1))
clf_layer = Conv2D(...)(clf_input)
clf_layer = MaxPooling2D(...)(clf_layer)
...
clf_output = Dense(num_classes, activation='softmax')(clf_layer)
model = Model(clf_input, clf_output)
model.compile(...)
model.fit(...)

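For concreteness, a minimal filled-in version of such a classifier might look like the following; the specific layer sizes and compile arguments are assumptions, loosely following the Keras MNIST CNN example, and are not taken from the question:

from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from keras.models import Model

num_classes = 10  # assumption: MNIST-style setup

clf_input = Input(shape=(28, 28, 1))
clf_layer = Conv2D(32, (3, 3), activation='relu')(clf_input)
clf_layer = MaxPooling2D(pool_size=(2, 2))(clf_layer)
clf_layer = Flatten()(clf_layer)
clf_layer = Dense(128, activation='relu')(clf_layer)
clf_output = Dense(num_classes, activation='softmax')(clf_layer)

model = Model(clf_input, clf_output)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# model.fit(x_train, y_train, ...)  # training data omitted here
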
And the autoencoder like this (based on this example):

ae_input = Input(shape=(28,28,1))
x = Conv2D(...)(ae_input)
x = MaxPooling2D(...)(x)
...
encoded = MaxPooling2D(...)(x)
x = Conv2D(...)(encoded)
x = UpSampling2D(...)(x)
...
decoded = Conv2D(...)(x)
autoencoder = Model(ae_input, decoded)
autoencoder.compile(...)
autoencoder.fit(...)

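Again for concreteness, a small convolutional autoencoder in that style could look like this; the layer sizes are illustrative assumptions, roughly following the Keras blog's convolutional autoencoder, not the question's exact architecture:

from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model

ae_input = Input(shape=(28, 28, 1))
x = Conv2D(16, (3, 3), activation='relu', padding='same')(ae_input)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)  # 7x7x8 bottleneck

x = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)  # back to 28x28x1

autoencoder = Model(ae_input, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
# autoencoder.fit(x_train, x_train, ...)  # trained to reconstruct its input
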
I can concatenate the two models like this (I still need the original models, hence the copying):

model_copy = keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())
# remove original input layer
model_copy.layers.pop(0)
# set the new input
new_clf_output = model_copy(decoded)
# get the stacked model
stacked_model = Model(ae_input, new_clf_output)
stacked_model.compile(...)

And this works great when all I want to do is apply the model to new test data, but it gives an error on something like this:

for layer in stacked_model.layers:
    print(layer.get_config())

where it gets through the autoencoder layers but then fails with a KeyError at the point where the classifier model gets its input. Also, when plotting the model with keras.utils.plot_model, I get this:

[plot of stacked_model]

where you can see the autoencoder layers but then at the end, instead of the individual layers from the classifier model, there is only the complete model in one block.
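One way to see this nesting directly (an illustrative check, not from the original question) is to print only the layer names and types, which should not trigger the KeyError:

for layer in stacked_model.layers:
    print(layer.name, type(layer).__name__)
# the last entry is the cloned classifier Model itself,
# rather than its individual Conv2D / Dense layers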

Is there a way to connect two models such that the new stacked model is actually made up of all the individual layers?


Solution

  • OK, what I came up with is to manually go through each layer of the classifier and reconnect them one by one, like this:

    l = model.layers[1](decoded)  # layer 0 is the input layer, which we're replacing
    for i in range(2, len(model.layers)):
        l = model.layers[i](l)
    stacked_model = Model(ae_input, l)
    stacked_model.compile(...)
    

    While this works, produces the correct plot, and raises no errors (see the verification sketch below), it does not seem like the most elegant solution...

    (btw, the copying of the model actually seems to be unnecessary as I'm not retraining anything.)
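
    For completeness, a quick verification sketch; the plot filename is an assumption. With the layers reconnected individually, both the config loop and the plot now show every layer:

    from keras.utils import plot_model

    # every individual classifier layer is now a layer of stacked_model
    for layer in stacked_model.layers:
        print(layer.get_config())

    plot_model(stacked_model, to_file='stacked_model.png', show_shapes=True)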