python, pytorch, feature-extraction, autoencoder, encoder-decoder

Extracting features from the hidden layer of an autoencoder using PyTorch


I am following this tutorial to train an autoencoder.

The training has gone well. Next, I would like to extract features from the hidden layer (between the encoder and decoder).

How should I do that?


Solution

  • The cleanest and most straightforward way is to add methods that return partial outputs. This can even be done a posteriori, on a model that has already been trained.

    import torch
    from torch import Tensor, nn
    
    class AE(nn.Module):
        def __init__(self, **kwargs):
            ...
    
        def encode(self, features: Tensor) -> Tensor:
            # Encoder half: input -> hidden-layer features
            h = torch.relu(self.encoder_hidden_layer(features))
            return torch.relu(self.encoder_output_layer(h))
    
        def decode(self, encoded: Tensor) -> Tensor:
            # Decoder half: hidden-layer features -> reconstruction
            h = torch.relu(self.decoder_hidden_layer(encoded))
            return torch.relu(self.decoder_output_layer(h))
    
        def forward(self, features: Tensor) -> Tensor:
            # Full pass is unchanged: encode, then decode
            encoded = self.encode(features)
            return self.decode(encoded)
    

    You can now query the model for encoder hidden states by simply calling encode with the corresponding input tensor.

    If you'd rather not add any methods to the model class (though I don't see why you wouldn't), you could instead write an external function:

    def get_encoder_state(model: AE, features: Tensor) -> Tensor:
        h = torch.relu(model.encoder_hidden_layer(features))
        return torch.relu(model.encoder_output_layer(h))
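To see both routes side by side, here is a minimal, self-contained sketch. The `__init__` in the question's tutorial is not shown above, so the layer sizes (784 inputs, 128 hidden units) and the `PlainAE` stand-in class are assumptions made only for illustration; only the four layer names come from the answer. It copies the weights of an "already trained" model into a subclass that adds `encode`, then checks that `encode` and the external function return the same features.

```python
import torch
from torch import Tensor, nn


class PlainAE(nn.Module):
    """Hypothetical stand-in for the tutorial model (layer sizes are assumed)."""

    def __init__(self, input_shape: int = 784, hidden: int = 128):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(input_shape, hidden)
        self.encoder_output_layer = nn.Linear(hidden, hidden)
        self.decoder_hidden_layer = nn.Linear(hidden, hidden)
        self.decoder_output_layer = nn.Linear(hidden, input_shape)

    def forward(self, features: Tensor) -> Tensor:
        h = torch.relu(self.encoder_hidden_layer(features))
        code = torch.relu(self.encoder_output_layer(h))
        h = torch.relu(self.decoder_hidden_layer(code))
        return torch.relu(self.decoder_output_layer(h))


class AE(PlainAE):
    """Same parameters as PlainAE, with an encode method added after the fact."""

    def encode(self, features: Tensor) -> Tensor:
        h = torch.relu(self.encoder_hidden_layer(features))
        return torch.relu(self.encoder_output_layer(h))


def get_encoder_state(model: nn.Module, features: Tensor) -> Tensor:
    # External-function route: works on any model exposing the two encoder layers.
    h = torch.relu(model.encoder_hidden_layer(features))
    return torch.relu(model.encoder_output_layer(h))


torch.manual_seed(0)
trained = PlainAE()                          # pretend this model has been trained
model = AE()
model.load_state_dict(trained.state_dict())  # reuse the trained weights
model.eval()

batch = torch.rand(4, 784)                   # 4 flattened inputs (assumed shape)
with torch.no_grad():                        # no gradients needed for extraction
    codes = model.encode(batch)
    codes_external = get_encoder_state(trained, batch)

print(codes.shape)
```

Wrapping the call in `torch.no_grad()` is optional, but it avoids building a computation graph when the features are only going to be stored or passed to another tool.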