
Keras AE with split decoder and encoder - But with multiple inputs


I'm trying to train an autoencoder in Keras. In the end I would like to have separate encoder and decoder models. I can do this for an ordinary AE, as shown here: https://blog.keras.io/building-autoencoders-in-keras.html

However, I would like to train a conditional variant of the model where I pass conditional information to the encoder and the decoder. (https://www.vadimborisov.com/conditional-variational-autoencoder-cvae.html)
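To make concrete what passing conditional information to both halves means, here is a minimal NumPy sketch (not Keras) of a conditional autoencoder's forward pass. It assumes the condition is a one-hot label that is concatenated with the data before the encoder and with the latent code before the decoder, which is what the concatenate calls in the Keras code below do. The weights are random and the helpers dense, encode and decode are hypothetical names used only for illustration; nothing is trained here.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense(x, w, b):
    # fully connected layer with ReLU, analogous to Dense(..., activation="relu")
    return np.maximum(0.0, x @ w + b)

# hypothetical random weights; sizes follow the question:
# 100-d data, 10-d condition, 5-d latent state
W_enc = rng.normal(size=(100 + 10, 5)) * 0.1
b_enc = np.zeros(5)
W_dec = rng.normal(size=(5 + 10, 100)) * 0.1
b_dec = np.zeros(100)

def encode(x, c):
    # the condition is concatenated with the data before encoding
    return dense(np.concatenate([x, c], axis=-1), W_enc, b_enc)

def decode(z, c):
    # the same condition is concatenated with the latent code before decoding
    return np.concatenate([z, c], axis=-1) @ W_dec + b_dec

x = rng.normal(size=(4, 100))
c = np.eye(10)[[0, 3, 3, 7]]  # one-hot class labels used as the condition
z = encode(x, c)
x_hat = decode(z, c)
print(z.shape, x_hat.shape)  # (4, 5) (4, 100)
```

Because the decoder also sees c, the same latent code z decodes differently under a different condition, which is why the condition has to be an input of both models.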

I can create the encoder and decoder fine:

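from tensorflow.keras.layers import Input, Dense, concatenate
from tensorflow.keras.models import Model

activation = "relu"  # assumed; the original snippet never defines `activation`

# encoder inputs: the data and the conditional vector
xIn = Input(shape=(100,), name="data_in")
conditional = Input(shape=(10,), name="conditional")

modelInput = concatenate([xIn, conditional])
x = Dense(25, activation=activation)(modelInput)
xlatent = Dense(5, activation=activation)(x)

# create the encoder: [data, conditional] -> latent state
cencoder = Model(inputs=[xIn, conditional], outputs=xlatent, name="Encoder")
cencoder.summary()


# decoder inputs: the latent state and the decoder's own conditional Input
latentState = Input(shape=(5,), name="latentInput")
conditional = Input(shape=(10,), name="conditional")

decoderInput = concatenate([conditional, latentState])
x = Dense(25, activation=activation)(decoderInput)
out = Dense(100, activation=activation)(x)  # reconstruct all 100 input features

# create the decoder: its inputs must be the tensors `out` was computed from
# (latentState and the decoder's conditional), not xIn
cdecoder = Model(inputs=[latentState, conditional], outputs=out, name="Decoder")
cdecoder.summary()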

To then build the full autoencoder from the two models, for an ordinary AE I would do something like:

encoded = encoder(input)
out = decoder(encoded)
AE = Model(input, out)

How do I do the equivalent when the conditional input has to go to both models, something like this:

encoded = encoder([input,conditional])
out = decoder([encoded,conditional])
AE = Model(encoded,out)

Whatever I try, I get a "Graph disconnected" error.



Solution

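  • Assuming the same conditional input is fed to both models

    A functional Model should be built from Input tensors, and its outputs must be computable from those inputs. Building it from an intermediate tensor such as encoded is what triggers the graph disconnected error. Instead, create new Input layers for the combined model and call the encoder and decoder on them as if they were layers: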

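    # fresh Input layers for the combined model (shapes match the sub-models' inputs)
    encoderInput = Input(shape=(100,), name="auto_data_in")
    conditionalInput = Input(shape=(10,), name="auto_conditional")
    
    # call the sub-models like layers; the same conditional tensor feeds both
    encoderOut = cencoder([encoderInput, conditionalInput])
    decoderOut = cdecoder([encoderOut, conditionalInput])
    
    # the autoencoder's inputs are the new Input tensors, not encoderOut
    AE = Model([encoderInput, conditionalInput], decoderOut)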
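The following is a self-contained sketch of how the combined model might be trained and the two halves then used on their own. It rebuilds the encoder and decoder with the question's layer sizes, assumes relu for the hidden layers and a linear output layer (a choice made here for real-valued data, not taken from the question), and trains for two epochs on random placeholder data only to show the mechanics. Because cencoder and cdecoder are called as layers inside AE, they share its weights, so training AE also trains the standalone models.

```python
import numpy as np
from tensorflow.keras.layers import Input, Dense, concatenate
from tensorflow.keras.models import Model

activation = "relu"  # assumption: the question does not define `activation`

# encoder: [data (100), conditional (10)] -> latent (5)
xIn = Input(shape=(100,), name="data_in")
cIn = Input(shape=(10,), name="conditional")
h = Dense(25, activation=activation)(concatenate([xIn, cIn]))
cencoder = Model([xIn, cIn], Dense(5, activation=activation)(h), name="Encoder")

# decoder: [latent (5), conditional (10)] -> reconstruction (100), linear output
zIn = Input(shape=(5,), name="latentInput")
cIn2 = Input(shape=(10,), name="dec_conditional")
h = Dense(25, activation=activation)(concatenate([cIn2, zIn]))
cdecoder = Model([zIn, cIn2], Dense(100)(h), name="Decoder")

# combined autoencoder, built the same way as in the solution
encoderInput = Input(shape=(100,), name="auto_data_in")
conditionalInput = Input(shape=(10,), name="auto_conditional")
encoderOut = cencoder([encoderInput, conditionalInput])
decoderOut = cdecoder([encoderOut, conditionalInput])
AE = Model([encoderInput, conditionalInput], decoderOut)

# random placeholder data, purely for illustration
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 100)).astype("float32")
C = rng.normal(size=(64, 10)).astype("float32")

# the reconstruction target is the data input itself
AE.compile(optimizer="adam", loss="mse")
AE.fit([X, C], X, epochs=2, batch_size=16, verbose=0)

# after training, the standalone models reuse the weights trained inside AE
latent = cencoder.predict([X, C], verbose=0)
recon = cdecoder.predict([latent, C], verbose=0)
full = AE.predict([X, C], verbose=0)
print(latent.shape, recon.shape)
print(np.allclose(recon, full, atol=1e-5))
```

Running the encoder and then the decoder separately should give the same result as AE itself, since no weights are copied: all three models point at the same layers.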