Search code examples
Tags: machine-learning, keras, autoencoder

How to connect separately created encoder and decoder together?


I want to create an autoencoder and experiment with it. I have created an encoder and a decoder, and with the summary() method I checked that both are correct and symmetric. The next step is to connect the encoder to the decoder so I get a final model I can train. I tried to connect them, but the final model's summary suggests that something is wrong: it does not show all the layers of the model (some parts of the decoder are missing). You can see my attempt in the buld_model() function below. Thanks!

[Screenshot of the model summary — image not included]

class AutoEncoder(object):
    """Convolutional autoencoder assembled from separate encoder and decoder models.

    The encoder is a stack of ``Conv2D`` + ``LeakyReLU`` layers followed by
    ``Flatten`` and a ``Dense`` projection to ``latent_space`` units. The
    decoder maps the latent vector back through a ``Dense`` layer, reshapes it
    to the encoder's last feature-map shape, and applies a stack of
    ``Conv2DTranspose`` layers (``LeakyReLU`` between them, ``sigmoid`` on the
    last one).

    Args:
        shape: Input shape without the batch dimension, e.g. ``(28, 28, 1)``.
        encoder_depth: Number of convolution layers in the encoder.
        decoder_depth: Number of transposed-convolution layers in the decoder.
        encoder_filters: Filter count for each encoder layer
            (must have ``encoder_depth`` entries).
        decoder_filters: Filter count for each decoder layer
            (must have ``decoder_depth`` entries).
        encoder_strides: Stride for each encoder layer
            (must have ``encoder_depth`` entries).
        decoder_strides: Stride for each decoder layer
            (must have ``decoder_depth`` entries).
        latent_space: Number of units in the latent (encoder output) layer.

    Raises:
        ValueError: If a filters or strides list does not have as many
            entries as the corresponding depth.

    After construction, ``self.encoder``, ``self.decoder`` and the compiled
    end-to-end ``self.model`` are available.
    """

    def __init__(self,
                 shape,
                 encoder_depth,
                 decoder_depth,
                 encoder_filters,
                 decoder_filters,
                 encoder_strides,
                 decoder_strides,
                 latent_space,
                 ):

        self.shape = shape

        # Check provided filters and strides. The original used
        # ``assert(cond, "msg")``, which asserts a non-empty tuple and is
        # therefore always true, so these checks never fired.
        self._check_length(encoder_filters, encoder_depth, "filters", "Encoder")
        self._check_length(decoder_filters, decoder_depth, "filters", "Decoder")
        self._check_length(encoder_strides, encoder_depth, "strides", "Encoder")
        self._check_length(decoder_strides, decoder_depth, "strides", "Decoder")

        # Depth and latent space
        self.encoder_depth = encoder_depth
        self.decoder_depth = decoder_depth
        self.latent_space = latent_space

        # Filters
        self.encoder_filters = encoder_filters
        self.decoder_filters = decoder_filters

        # Strides
        self.encoder_strides = encoder_strides
        self.decoder_strides = decoder_strides

        self.build_model()

    @staticmethod
    def _check_length(values, depth, what, part):
        """Raise ValueError unless ``values`` has exactly ``depth`` entries."""
        if len(values) != depth:
            raise ValueError(
                "Number of {} ({}) must match the depth of the {} ({})".format(
                    what, len(values), part, depth))

    def build_encoder(self):
        """Build ``self.encoder`` and record the tensors the full model reuses.

        Sets ``self.encoder_input``, ``self.encoder_output``,
        ``self.shape_before_flat`` (the last conv feature-map shape, without
        the batch dimension) and ``self.encoder``.
        """
        input_x = Input(shape=self.shape, name="encoder_input")
        x = input_x

        for i in range(self.encoder_depth):
            x = Conv2D(self.encoder_filters[i],
                       kernel_size=3,
                       strides=self.encoder_strides[i],
                       padding="same",
                       name="encoder_conv_" + str(i))(x)
            x = LeakyReLU()(x)

        # Remembered so the decoder can reshape its Dense output back to it.
        self.shape_before_flat = K.int_shape(x)[1:]
        x = Flatten()(x)
        encoder_output = Dense(self.latent_space, name="Encoder_output")(x)

        self.encoder_output = encoder_output
        self.encoder_input = input_x

        self.encoder = tf.keras.Model(self.encoder_input, self.encoder_output)

    def build_decoder(self):
        """Build ``self.decoder``, mapping a latent vector back to feature maps.

        Must run after ``build_encoder`` because it reads
        ``self.shape_before_flat``.
        """
        decoder_input = Input(shape=(self.latent_space,), name="decoder_input")

        # int(): np.prod returns a NumPy integer; Dense expects a unit count.
        x = Dense(int(np.prod(self.shape_before_flat)))(decoder_input)
        x = Reshape(self.shape_before_flat)(x)

        for i in range(self.decoder_depth):
            x = Conv2DTranspose(self.decoder_filters[i],
                                kernel_size=3,
                                strides=self.decoder_strides[i],
                                padding="same",
                                name="decoder_conv_t_" + str(i))(x)
            # LeakyReLU between layers; sigmoid on the final output layer.
            if i < self.decoder_depth - 1:
                x = LeakyReLU()(x)
            else:
                x = Activation("sigmoid")(x)

        # NOTE(review): for reconstruction, decoder_filters[-1] and the
        # decoder strides presumably need to reproduce ``self.shape``;
        # this is not checked here.
        decoder_output = x
        self.decoder = tf.keras.Model(decoder_input, decoder_output)

    def build_model(self):
        """Build encoder and decoder, chain them into ``self.model``, and compile it.

        The decoder model is applied to the encoder's output tensor, so in
        ``self.model.summary()`` the decoder appears as a single nested layer.
        Use ``self.decoder.summary()`` to see its inner layers.
        """
        self.build_encoder()
        self.build_decoder()

        model_input = self.encoder_input
        model_output = self.decoder(self.encoder_output)

        self.model = tf.keras.Model(model_input, model_output)
        # ``loss`` is the module-level per-sample MSE function.
        self.model.compile(optimizer="adam", loss=loss)

    # Backward-compatible alias for the original (misspelled) method name.
    buld_model = build_model

def loss(y_true, y_pred):
    """Per-sample mean squared error between targets and predictions.

    Squares the element-wise difference and averages it over axes 1, 2 and 3.
    Presumably the inputs are rank-4 image batches such as
    (batch, height, width, channels), so one value remains per sample.
    TODO confirm the layout.
    """
    squared_error = K.square(y_true - y_pred)
    return K.mean(squared_error, axis=[1, 2, 3])

# Build a 4-layer convolutional autoencoder for 28x28 single-channel inputs
# with a 2-unit latent layer, then print the combined model's layer summary.
autoencoder = AutoEncoder(
    shape=(28, 28, 1),
    encoder_depth=4,
    decoder_depth=4,
    encoder_filters=[32, 64, 64, 64],
    decoder_filters=[64, 64, 32, 1],
    encoder_strides=[1, 2, 2, 1],
    decoder_strides=[1, 2, 2, 1],
    latent_space=2,
)

autoencoder.model.summary()

Solution

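  • You almost got it! Create a fresh `Input`, pass it through the encoder model, and then pass the encoder's output through the decoder model. Note that `summary()` of the combined model shows each sub-model as a single nested layer, so the decoder's inner layers are not listed individually. Call `encoder.summary()` and `decoder.summary()` to inspect them. It should look something like this: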

        # Fresh input tensor for the combined model (same shape as the encoder's input).
        img = Input(shape=self.img_shape)
        # Apply the encoder sub-model; it shows up as one nested layer in summary().
        encoded_repr = self.encoder(img)
        # Apply the decoder sub-model to the encoder's output.
        reconstructed_img = self.decoder(encoded_repr)
    
        # End-to-end model: input image in, reconstruction out.
        self.autoencoder = Model(img, reconstructed_img)
        # NOTE(review): `optimizer` is not defined in this snippet; pass e.g. "adam" or an optimizer instance.
        self.autoencoder.compile(loss='mse', optimizer=optimizer)
    

    I hope it helps!