Tags: python, tensorflow, autoencoder, polar-coordinates

Autoencoders and Polar Coordinates


Can an autoencoder learn the transformation into polar coordinates? If a set of 2D data lies approximately on a circle, there is a lower-dimensional manifold, parameterized by the angle, that describes the data 'best'.

I tried various versions without success.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

class DeepAutoEncoder(Model):
  def __init__(self, dim_data: int, num_hidden: int,
               num_comp: int,
               activation: str = 'linear'):
    super(DeepAutoEncoder, self).__init__()
    self.encoder = tf.keras.Sequential([
      layers.Dense(num_hidden, activation=activation),
      layers.Dense(num_comp, activation=activation),
    ])

    self.decoder = tf.keras.Sequential([
      layers.Dense(num_hidden, activation=activation),
      layers.Dense(dim_data, activation='linear')
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded



# Data: noisy unit circle, angle t, Gaussian noise on y
num_obs = 1000
np.random.seed(1238)
e = np.random.randn(num_obs, 1)
t = np.linspace(0, 2*np.pi, num_obs)
x = np.cos(t)
y = np.sin(t) + 0.2*e[:, 0]
X = np.column_stack((x, y))

num_comp = 1
activations = ['linear', 'sigmoid']
ae = {a: None for a in activations}
for act in activations:
    ae[act] = DeepAutoEncoder(dim_data=2, num_comp=num_comp,
                              num_hidden=3, activation=act)
    ae[act].compile(optimizer=Adam(learning_rate=0.01),
                    loss='mse')
    ae[act].build(input_shape=(None, 2))
    ae[act].summary()
    history = ae[act].fit(X, X, epochs=200,
                          batch_size=32,
                          shuffle=True)
    plt.plot(history.history["loss"], label=act)
plt.legend()

f, axs = plt.subplots(2, 2)
for i, a in enumerate(activations):
    axs[0, i].plot(x, y, '.', c='k')
    z = ae[a].encoder(X)
    # x_ae = ae[a].decoder(ae[a].encoder(X))
    x_ae = ae[a](X)
    axs[0, i].plot(x_ae[:, 0], x_ae[:, 1], '.')
    # axs[0, i].plot(x_pca[:, 0], x_pca[:, 1], '.', c='C3')
    axs[1, i].plot(z)
    axs[0, i].axis('equal')
    axs[0, i].set(title=a)


The reconstructed data looks like:

[Figure: reconstructions produced by the linear and sigmoid autoencoders]

I assume that the reason is that a transformation sigmoid(W * z + b) is far from the rotation-like mapping [[cos(theta), sin(theta)], [-sin(theta), cos(theta)]] required to map the latent variable back into the original space.
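
For comparison, the decoder would ideally implement the analytic map from the latent angle back onto the circle, i.e. theta -> (cos(theta), sin(theta)). A minimal sketch of that ideal mapping (illustrative only, not part of the model):

import numpy as np

def ideal_decoder(theta):
    # Map a latent angle back onto the noise-free unit circle
    return np.column_stack((np.cos(theta), np.sin(theta)))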

Any thoughts would be great!

Thank you very much.


Solution

  • Can an autoencoder learn the transformation into polar coordinates? If a set of 2D data lies approximately on a circle, there is a lower-dimensional manifold, parameterized by the angle, that describes the data 'best'.

    Neural nets can be trained to learn arbitrary transformations, including analytical ones like mappings between coordinate systems. In the present case the net will encode the (x, y) coordinates using an arbitrary encoding that will correlate strongly with the angle (a 'pseudo-angle', if you like).

    The neural net in your question is trying to encode three key variables into a 1D space (the single encoder unit): the sign of x, the sign of y, and their relative sizes. All three pieces of information are required to determine the correct quadrant. I think the main reason the net is not learning is that its capacity is too small to capture the complexity of arctan2.
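
    One way to inspect this pseudo-angle is to plot the 1-D code against the true angle arctan2(y, x). A minimal sketch, assuming the trained model ae['sigmoid'] and the data X from the question:

    import numpy as np
    from matplotlib import pyplot as plt

    #True angle of each input point, in (-pi, pi]
    theta = np.arctan2(X[:, 1], X[:, 0])

    #1-D code from the trained encoder (assumes ae and X from the question)
    z = np.asarray(ae['sigmoid'].encoder(X)).ravel()

    #A monotone curve means the code is a usable pseudo-angle; a curve that
    #folds back on itself means two different angles share one code.
    plt.scatter(theta, z, s=2)
    plt.xlabel('arctan2(y, x)')
    plt.ylabel('encoder output')
    plt.show()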

    Suppose we limit the data to x > 0; then the only thing the net needs to encode is the sign of y and its size relative to x. In that case your net works fine, as it just needs to learn arctan:

    [Figure: reconstructions when the data are restricted to x > 0]
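
    To reproduce this half-plane experiment, it is enough to mask the question's data before training (a minimal sketch, using X from the question):

    #Keep only points in the right half-plane, so the net
    #only has to learn arctan rather than arctan2
    X_half = X[X[:, 0] > 0]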

    The figure below illustrates how the learnt encoding carries information about the angle, allowing the net to uniquely determine the location along the circle.

    [Figure: inputs coloured by the learnt encoding (x > 0 case)]

    Notice how there is a unique value of the encoding for each point along the circumference. There's a linear relationship between the inputs and their reconstructions, indicating that the net has learnt to reproduce them.

    As soon as you allow x to be both positive and negative, the net needs to learn the more complex arctan2 function. It fails and instead just captures arctan: your results show that y is being encoded correctly, but for any given y the net cannot determine which side of the plane x should be on, so x averages out to 0 while the projection onto y remains correct.
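
    The failure mode is easy to demonstrate numerically: arctan sees only the ratio y/x, so diagonally opposite points collide, whereas arctan2 keeps the signs of both arguments and can separate them:

    import numpy as np

    #arctan sees only the ratio y/x, so (1, 1) and (-1, -1) collide
    print(np.arctan(1.0 / 1.0))      # 0.785 (45 degrees)
    print(np.arctan(-1.0 / -1.0))    # 0.785 (same value!)

    #arctan2 uses the signs of y and x separately and tells them apart
    print(np.arctan2(1.0, 1.0))      # 0.785 (45 degrees)
    print(np.arctan2(-1.0, -1.0))    # -2.356 (-135 degrees)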

    [Figure: encoding traced along the circle; x reconstructions; y reconstructions]

    The figure on the left illustrates what is happening. The encoding is correct if you trace it from +90° to -90°, but then it repeats: it captures the correct angles in the right-hand plane and duplicates them for the left-hand plane. Both positive and negative x correspond to the same encoding, leading to x averaging out to 0. The second figure shows how, for any x, the net basically predicts 0, whilst the third shows that it learns the correct positioning for y.

    I made the following changes, all of which I found important for improving performance on this dataset with the given model:

    • Make the encoder deeper (and, to a lesser extent, the decoder as well)
    • Use tanh activations rather than ReLU or sigmoid, consistent with the data's range
    • Use a small batch size, giving the net more steps to explore the loss landscape

    I've tried to keep the architecture close to the original in order to demonstrate the point of network depth. The results below are with an easier dataset (less noise), showing the model's performance after 75 epochs:

    Model comprises 135 trainable parameters
    [epoch  1/75] trn loss: 0.291 [rmse: 0.540] | val loss: 0.284 [rmse: 0.533]
    [epoch  5/75] trn loss: 0.252 [rmse: 0.502] | val loss: 0.248 [rmse: 0.498]
    ...
    [epoch 70/75] trn loss: 0.005 [rmse: 0.072] | val loss: 0.005 [rmse: 0.074]
    [epoch 75/75] trn loss: 0.009 [rmse: 0.095] | val loss: 0.005 [rmse: 0.070]
    

    [Figure: inputs coloured by the learnt encoding (full circle)]

    The net has learnt a unique encoding for each point on the circle. There's a discontinuity near y = 0 that I haven't looked into, though some seam is to be expected: a closed circle cannot be mapped continuously and one-to-one onto a 1D interval, so a 1D encoding has to break somewhere.

    [Figure: reconstructions overlaid on the inputs]

    The reconstructions generally track the inputs.

    In going from modelling half of the plane (arctan) to the full plane (arctan2), the model grows from about 27 to 135 parameters (5x) and from 1 hidden layer to 9. Whilst arctan is a single equation, arctan2 is discontinuous along the negative x-axis and is defined piecewise: three equations rather than one, plus the special cases on the x = 0 axis (written out in the sketch below). The required depth seems to grow quickly with this additional complexity, suggesting the model needs higher-level encodings rather than merely more detailed ones. More efficient architectures might be equally expressive with less depth, in which case we wouldn't need as many layers, but this example sticks to a simple stacked arrangement.
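
    For concreteness, here is arctan2 written out as the piecewise function the model effectively has to approximate (a plain illustration, not part of the model code):

    import numpy as np

    def arctan2_piecewise(y: float, x: float) -> float:
        #Textbook piecewise definition of arctan2
        if x > 0:                        #right half-plane: plain arctan
            return np.arctan(y / x)
        if x < 0 and y >= 0:             #upper-left quadrant
            return np.arctan(y / x) + np.pi
        if x < 0 and y < 0:              #lower-left quadrant
            return np.arctan(y / x) - np.pi
        #Special cases on the y-axis (x == 0)
        if y > 0:
            return np.pi / 2
        if y < 0:
            return -np.pi / 2
        return 0.0                       #origin: undefined; 0 by convention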


    PyTorch example code.

    import numpy as np
    from matplotlib import pyplot as plt
    
    #Data as per the question, but with less noise (see text)
    num_obs = 1000
    np.random.seed(1238)
    e = np.random.randn(num_obs)
    #Shuffle the angles in advance so the train/validation split covers the circle
    t = np.linspace(0, 2 * np.pi, num_obs)[np.random.permutation(num_obs)]
    
    x = np.cos(t)
    y = np.sin(t) + 0.2 * e / 10    #question's 0.2 noise, scaled down by 10
    data = np.column_stack((x, y))
    # data = data[x > 0, :] #Uncomment to limit to the right-hand (x > 0) plane
    
    
    #
    #Split the data (just train & validation for this demo)
    #
    n_train = int(0.7 * len(data))
    
    data_trn = data[:n_train]
    data_val = data[n_train:]
    
    f, (ax_trn, ax_val) = plt.subplots(
        ncols=2, figsize=(5.2, 3), sharex=True, sharey=True, layout='tight'
    )
    
    for ax, arr in [[ax_trn, data_trn], [ax_val, data_val]]:
        ax.scatter(arr[:, 0], arr[:, 1], marker='.', s=2, color='dodgerblue')
    
    ax_trn.set_title('train')
    ax_trn.set(xlabel='x', ylabel='y')
    ax_trn.spines[['top', 'right']].set_visible(False)
    
    ax_val.set_title('validation')
    ax_val.set_xlabel('x')
    ax_val.spines[['top', 'right', 'left']].set_visible(False)
    ax_val.tick_params(axis='y', left=False)
    
    #
    # Prepare data for training
    #
    import torch
    from torch import nn
    
    #Data to float tensors
    X_trn = torch.tensor(data_trn).float()
    X_val = torch.tensor(data_val).float()
    
    #
    #Define the model
    #
    torch.manual_seed(1000)
    
    n_features = X_trn.shape[1]
    hidden_size = 3
    latent_size = 1
    
    activation_layer = nn.Tanh()
    
    encoder = nn.Sequential(
        nn.Linear(n_features, hidden_size),
        activation_layer,
        
        nn.Linear(hidden_size, hidden_size),
        activation_layer,
        nn.Linear(hidden_size, hidden_size),
        activation_layer,
        nn.Linear(hidden_size, hidden_size),
        activation_layer,
        nn.Linear(hidden_size, hidden_size),
        activation_layer,
        nn.Linear(hidden_size, hidden_size),
        activation_layer,
        nn.Linear(hidden_size, hidden_size),
        activation_layer,
        nn.Linear(hidden_size, hidden_size),
        activation_layer,
    
        nn.Linear(hidden_size, latent_size),    
    )
    
    decoder = nn.Sequential(
        activation_layer,
    
        nn.Linear(latent_size, hidden_size),
        activation_layer,
        nn.Linear(hidden_size, hidden_size),
        activation_layer,
        nn.Linear(hidden_size, hidden_size),
        activation_layer,
        
        nn.Linear(hidden_size, n_features),
    )
    
    autoencoder = nn.Sequential(encoder, decoder)
    
    print(
        'Model comprises',
        sum([p.numel() for p in autoencoder.parameters() if p.requires_grad]),
        'trainable parameters'
    )
    
    optimiser = torch.optim.NAdam(autoencoder.parameters())
    loss_fn = nn.MSELoss()
    
    #
    # Training loop
    #
    metrics_dict = dict(epoch=[], trn_loss=[], val_loss=[])
    
    for epoch in range(n_epochs := 75):
        autoencoder.train()
    
        #Batch size of 1: one optimiser step per sample (see text)
        train_shuffled = X_trn[torch.randperm(len(X_trn))]
        for sample in train_shuffled:
            recon = autoencoder(sample).ravel()
            loss = loss_fn(recon, sample)
    
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
        #/end of epoch
    
        if not (epoch == 0 or (epoch + 1) % 5 == 0):
            continue
    
        autoencoder.eval()
        with torch.no_grad():
            trn_recon = autoencoder(X_trn)
            val_recon = autoencoder(X_val)
    
            trn_encodings = encoder(X_trn)
            val_encodings = encoder(X_val)
            
        trn_loss = loss_fn(trn_recon, X_trn)
        val_loss = loss_fn(val_recon, X_val)
    
        print(
            f'[epoch {epoch + 1:>2d}/{n_epochs:>2d}]',
            f'trn loss: {trn_loss:>5.3f} [rmse: {trn_loss**0.5:>5.3f}] |',
            f'val loss: {val_loss:>5.3f} [rmse: {val_loss**0.5:>5.3f}]'
        )
    
        metrics_dict['epoch'].append(epoch + 1)
        metrics_dict['trn_loss'].append(trn_loss.item())
        metrics_dict['val_loss'].append(val_loss.item())
    
    #Overlay results
    for ax, recon in [[ax_trn, trn_recon], [ax_val, val_recon]]:
        ax.scatter(recon[:, 0], recon[:, 1], color='crimson', marker='.', s=2)
    
    #Legend
    ax_trn.scatter([], [], color='dodgerblue', marker='.', s=5, label='data')
    ax_trn.scatter([], [], color='crimson', marker='.', s=5, label='recon')
    f.legend(framealpha=1, scatterpoints=5, loc='upper left', labelspacing=0.05)
    
    #View learning curve
    f, ax = plt.subplots(figsize=(6, 2))
    
    for key in ['trn_loss', 'val_loss']:
        ax.plot(
            metrics_dict['epoch'], metrics_dict[key],
            marker='o', linestyle='-' if 'trn' in key else '--', label=key[:3]
        )
    ax.set(xlabel='epoch', ylabel='MSE loss')
    ax.legend(framealpha=0, loc='upper right')
    ax.spines[['top', 'right']].set_visible(False)
    
    #View encodings
    f, axs = plt.subplots(ncols=4, figsize=(10, 3), layout='tight')
    cmap = 'seismic'
    
    ax = axs[0]
    im = ax.scatter(X_trn[:, 0], X_trn[:, 1], c=trn_encodings, cmap=cmap, marker='.')
    ax.set(xlabel='x', ylabel='y', title='inputs & learnt encoding')
    
    ax = axs[1]
    ax.scatter(X_trn[:, 0], trn_recon[:, 0], c=trn_encodings, cmap=cmap, marker='.')
    ax.set(xlabel='x', ylabel='recon_x', title='x recon')
    
    ax = axs[2]
    ax.scatter(X_trn[:, 1], trn_recon[:, 1], c=trn_encodings, cmap=cmap, marker='.')
    ax.set(xlabel='y', ylabel='recon_y', title='y recon')
    
    ax = axs[3]
    ax.scatter(trn_recon[:, 0], trn_recon[:, 1], c=trn_encodings, cmap=cmap, marker='.')
    ax.set(xlabel='x_recon', ylabel='y_recon', title='recon')
    
    for ax in axs:
        ax.set(xlim=[-1.5, 1.5], ylim=[-1.5, 1.5])
    f.colorbar(im, label='encoder output', ax=axs[3])