Tags: python, deep-learning, neural-network, pytorch, autoencoder

Extracting hidden features from Autoencoders using Pytorch


Following the tutorials in this post, I am trying to train an autoencoder and extract the features from its hidden layer.

So here are my questions:

  1. In the autoencoder class there is a "forward" function, but I cannot see anywhere in the code where this function is called. So how is it used during training?

  2. I ask the question above because I think that, to extract the features, I should add another function ("forward_hidden") to the autoencoder class:

     def forward(self, features):
         activation = self.encoder_hidden_layer(features)
         activation = torch.relu(activation)
         code = self.encoder_output_layer(activation)
         code = torch.relu(code)
         activation = self.decoder_hidden_layer(code)
         activation = torch.relu(activation)
         activation = self.decoder_output_layer(activation)
         reconstructed = torch.relu(activation)
         return reconstructed
    
     def forward_hidden(self, features):
         activation = self.encoder_hidden_layer(features)
         activation = torch.relu(activation)
         code = self.encoder_output_layer(activation)
         code = torch.relu(code)
         return code
    

Then, after training, i.e. after this line in the main code:

print("AE, epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs_AE, loss))

I could add the following line to retrieve the features from the hidden layer:

hidden_features = model_AE.forward_hidden(my_input)

Is this approach correct? I am also still wondering how the "forward" function is used during training, because I cannot see it being called anywhere in the code.
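A minimal end-to-end sketch of the approach described in the question, under stated assumptions: the question does not show its __init__, so the layer sizes (784, 128 and 32), the random training data, and the Adam/MSE setup are all hypothetical. The attribute names match the question's forward code. Note that forward is never called by name; model_AE(data) reaches it through nn.Module.__call__.

```python
import torch
from torch import nn

class AE(nn.Module):
    # Assumed sizes for this sketch; the original __init__ is not shown.
    def __init__(self, input_shape=784, hidden=128, code=32):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(input_shape, hidden)
        self.encoder_output_layer = nn.Linear(hidden, code)
        self.decoder_hidden_layer = nn.Linear(code, hidden)
        self.decoder_output_layer = nn.Linear(hidden, input_shape)

    def forward_hidden(self, features):
        # Encoder half only: returns the (ReLU-activated) bottleneck code.
        activation = torch.relu(self.encoder_hidden_layer(features))
        return torch.relu(self.encoder_output_layer(activation))

    def forward(self, features):
        code = self.forward_hidden(features)  # reuse the encoder half
        activation = torch.relu(self.decoder_hidden_layer(code))
        return torch.relu(self.decoder_output_layer(activation))

torch.manual_seed(0)
model_AE = AE()
optimizer = torch.optim.Adam(model_AE.parameters(), lr=1e-3)
criterion = nn.MSELoss()
data = torch.rand(64, 784)  # random stand-in for real training data

epochs_AE = 3
for epoch in range(epochs_AE):
    optimizer.zero_grad()
    reconstructed = model_AE(data)  # implicitly calls forward()
    loss = criterion(reconstructed, data)
    loss.backward()
    optimizer.step()
    print("AE, epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs_AE, loss.item()))

# After training: switch to eval mode, disable gradient tracking,
# and call the encoder-only method directly.
model_AE.eval()
my_input = torch.rand(10, 784)
with torch.no_grad():
    hidden_features = model_AE.forward_hidden(my_input)
print(hidden_features.shape)  # torch.Size([10, 32])
```

Having forward call forward_hidden keeps the encoder defined in one place, so the extracted features are guaranteed to be exactly what the decoder sees. model_AE.eval() changes nothing for this particular model (it has no dropout or batch-norm layers) but is a sensible habit; torch.no_grad() avoids building an autograd graph during extraction.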


Solution

  • forward is the core of your model: it defines what the model actually computes.

    It is called implicitly whenever you call model(input) during training: nn.Module.__call__ runs any registered hooks and then calls forward for you.

    If you are asking how to extract intermediate features after running the model, you can register a forward hook, as described here, which will "catch" the values for you.
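To illustrate the implicit call, a minimal sketch using a hypothetical toy module (Doubler is not from the question) that counts how often its forward method runs:

```python
import torch
from torch import nn

class Doubler(nn.Module):
    """Hypothetical toy module used only to show how forward is dispatched."""
    def __init__(self):
        super().__init__()
        self.forward_calls = 0  # counts how often forward runs

    def forward(self, x):
        self.forward_calls += 1
        return 2 * x

model = Doubler()
x = torch.tensor([1.0, 2.0])

# model(x) goes through nn.Module.__call__, which runs any registered
# hooks and then calls self.forward(x); forward is never named here.
y = model(x)
print(y)                    # tensor([2., 4.])
print(model.forward_calls)  # 1
```

Calling model.forward(x) directly would also compute the result, but it silently skips hooks registered on model itself, which is why model(x) is the idiomatic form.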
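A minimal sketch of the forward-hook approach, using a hypothetical two-layer TinyAE whose sizes (784 and 32) are assumptions; only the attribute name encoder_output_layer mirrors the question's class. register_forward_hook is a standard nn.Module method; the hook below stores a detached copy of the layer's output every time the model runs.

```python
import torch
from torch import nn

# Hypothetical stand-in for the autoencoder; sizes are assumptions.
class TinyAE(nn.Module):
    def __init__(self, input_dim=784, code_dim=32):
        super().__init__()
        self.encoder_output_layer = nn.Linear(input_dim, code_dim)
        self.decoder_output_layer = nn.Linear(code_dim, input_dim)

    def forward(self, x):
        code = torch.relu(self.encoder_output_layer(x))
        return torch.relu(self.decoder_output_layer(code))

captured = {}

def save_output(name):
    # Forward hooks receive (module, inputs, output); keep a detached copy.
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

model = TinyAE()
handle = model.encoder_output_layer.register_forward_hook(save_output("code_pre_relu"))

x = torch.randn(5, 784)
with torch.no_grad():
    reconstructed = model(x)  # the hook fires while forward runs

handle.remove()  # stop capturing once the features are collected

# torch.relu is a function call, not a submodule, so the hook sees the
# Linear layer's output before the ReLU; apply it here to get the code.
hidden_features = torch.relu(captured["code_pre_relu"])
print(hidden_features.shape)  # torch.Size([5, 32])
```

The advantage over a separate forward_hidden method is that the model class stays unchanged; the catch is that a hook only sees outputs of submodules, so activations applied as plain function calls (like torch.relu here) must be reapplied by hand or replaced with nn.ReLU modules.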