Tags: tensorflow, machine-learning, autoencoder, gradienttape, custom-training

Custom Training Loop for Tensorflow Variational Autoencoder: `tape.gradient(loss, decoder_model.trainable_weights)` Always Returns a List Full of Nones


I am trying to write a custom training loop for a variational autoencoder (VAE) that consists of two separate tf.keras.Model objects. The objective of this VAE is multi-class classification. As usual, the outputs of the encoder model are fed as inputs to the decoder model, which is a recurrent decoder. Also as usual, the VAE involves two loss functions: a reconstruction loss (categorical cross-entropy) and a latent loss. My current architecture is inspired by a PyTorch implementation in this GitHub repository.
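For reference, here is what the `sampling` and `latent_loss_fxn` functions in the code below compute, written out. $B$ is `batch_size`, $T$ is `max_seq_length`, $C$ is `num_classes`, $D$ is the latent dimension, and $\log\sigma^2$ is what the code calls `log_vars`. The reconstruction term assumes the default reduction of `tf.keras.losses.CategoricalCrossentropy`, which averages the per-token loss over the batch and time axes.

```latex
z = \mu + \exp\!\left(\tfrac{1}{2}\log\sigma^2\right)\odot\epsilon,\qquad \epsilon\sim\mathcal{N}(0, I)

\mathcal{L}_{\text{recon}} = -\frac{1}{BT}\sum_{b=1}^{B}\sum_{t=1}^{T}\sum_{c=1}^{C} y_{btc}\,\log\hat{y}_{btc}

\mathcal{L}_{\text{latent}} = -\frac{1}{2}\cdot\frac{1}{BD}\sum_{b=1}^{B}\sum_{d=1}^{D}\left(1+\log\sigma^2_{bd}-\sigma^2_{bd}-\mu^2_{bd}\right),
\qquad \mathcal{L} = \mathcal{L}_{\text{recon}} + \mathcal{L}_{\text{latent}}
```

Note that the textbook KL divergence to $\mathcal{N}(0, I)$ sums over the $D$ latent dimensions. Because `latent_loss_fxn` uses a mean over all elements, it equals that KL term scaled by $1/D$ (and averaged over the batch).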

Problem: Whenever I compute the decoder gradients with tape.gradient(loss, decoder_model.trainable_weights), every element of the returned list is None. I assume I am making some mistake with reconstruction_tensor, which is near the bottom of the code below. Since I need the iterative decoding process, how can I use something like reconstruction_tensor without getting a list of None gradients? You may run the code using this colab notebook if you wish.
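The suspicion about reconstruction_tensor can be checked in isolation. The minimal sketch below uses two hypothetical stand-ins: `w` plays the role of a decoder weight and `buffer` plays the role of reconstruction_tensor. It compares writing a value into a tf.Variable with stacking it. `assign` is a stateful write. When the variable is read afterwards, the tape only sees a read of the variable itself, with no recorded link back to the assigned value. As a result, the gradient with respect to `w` comes back as None. Stacking keeps the computation inside the recorded graph.

```python
import tensorflow as tf

w = tf.Variable(2.0)                   # hypothetical stand-in for a decoder weight
buffer = tf.Variable(tf.zeros(3))      # hypothetical stand-in for reconstruction_tensor

with tf.GradientTape(persistent=True) as tape:
    y = w * 3.0                        # a "decoder output" that depends on w
    buffer[0].assign(y)                # stateful write: the tape does not link y to buffer
    loss_from_variable = tf.reduce_sum(buffer)
    loss_from_stack = tf.reduce_sum(tf.stack([y, y, y]))  # stays in the recorded graph

grad_via_assign = tape.gradient(loss_from_variable, w)
grad_via_stack = tape.gradient(loss_from_stack, w)
del tape

print(grad_via_assign)                 # None
print(float(grad_via_stack))           # 9.0, since loss_from_stack = 9 * w
```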

To clarify what the tensors in this problem look like, here are the original input, the zeros tensor to which the predicted 'tokens' will be assigned, and that zeros tensor after a single update with the predicted 'tokens' from the decoder:

Example original input tensor of shape (batch_size, max_seq_length, num_classes):
[[[1, 0, 0, 0],
  [0, 1, 0, 0],
  [0, 0, 1, 0]],
 [[0, 1, 0, 0],
  [1, 0, 0, 0],
  [0, 0, 0, 1]],
 [[0, 0, 0, 1],
  [1, 0, 0, 0],
  [0, 1, 0, 0]]]

Initial zeros tensor:
[[[0, 0, 0, 0],
  [0, 0, 0, 0],
  [0, 0, 0, 0]],
 [[0, 0, 0, 0],
  [0, 0, 0, 0],
  [0, 0, 0, 0]],
 [[0, 0, 0, 0],
  [0, 0, 0, 0],
  [0, 0, 0, 0]]]

Example zeros tensor after a single iteration of the decoding loop:
[[[0.2, 0.4, 0.1, 0.3],
  [0, 0, 0, 0],
  [0, 0, 0, 0]],
 [[0.1, 0.2, 0.6, 0.1],
  [0, 0, 0, 0],
  [0, 0, 0, 0]],
 [[0.7, 0.05, 0.05, 0.2],
  [0, 0, 0, 0],
  [0, 0, 0, 0]]]
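The "fill a zeros tensor one row at a time" picture above can also be kept differentiable without tf.Variable. This is a swapped-in technique, not what the original post does. Each update is built out of place with a one-hot time mask, so every step produces a new tensor that the tape records. In the sketch below, `w`, `step_input`, and `targets` are hypothetical placeholders for the decoder weights, the decoder input, and original_inputs. The targets reuse the class indices from the example input above.

```python
import tensorflow as tf

batch_size, max_seq_length, num_classes = 3, 3, 4

# Hypothetical stand-ins: a weight the loss should depend on and a fixed per-step input
w = tf.Variable(tf.random.normal((num_classes, num_classes)))
step_input = tf.random.normal((batch_size, num_classes))
targets = tf.one_hot([[0, 1, 2], [1, 0, 3], [3, 0, 1]], depth=num_classes)

with tf.GradientTape() as tape:
    reconstruction = tf.zeros((batch_size, max_seq_length, num_classes))
    for ith_token in range(max_seq_length):
        probs = tf.nn.softmax(tf.matmul(step_input, w))           # (batch, num_classes)
        # (1, max_seq_length, 1) mask: 1.0 at ith_token, 0.0 elsewhere
        time_mask = tf.reshape(tf.one_hot(ith_token, max_seq_length), (1, -1, 1))
        # Out-of-place "assignment": builds a new tensor, so the tape records it
        reconstruction = (reconstruction * (1.0 - time_mask)
                          + probs[:, tf.newaxis, :] * time_mask)
    # Categorical cross-entropy averaged over batch and time
    loss = -tf.reduce_mean(tf.reduce_sum(targets * tf.math.log(reconstruction), axis=-1))

grad = tape.gradient(loss, w)
print(grad is None)  # False
```

Each step touches the whole tensor, which costs more than appending to a list and calling tf.stack once. For short sequences like this one the difference is negligible.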

Here is the code to reproduce the problem:

import numpy as np
import tensorflow as tf

# Arbitrary data
batch_size = 3  
max_seq_length = 3
num_classes = 4
original_inputs = tf.one_hot(tf.argmax((np.random.randn(batch_size, max_seq_length, num_classes)), axis=2), depth=num_classes)
latent_dims = 5  # Must be less than (max_seq_length * num_classes)

def sampling(inputs):
    """Reparametrization function. Used for Lambda layer"""

    mus, log_vars = inputs
    epsilon = tf.random.normal(shape=tf.shape(mus))
    z = mus + tf.exp(log_vars / 2) * epsilon

    return z

def latent_loss_fxn(mus, log_vars):
    """Return latent loss for means and log variance."""

    return -0.5 * tf.reduce_mean(1. + log_vars - tf.exp(log_vars) - tf.square(mus))

class DummyEncoder(tf.keras.Model):
    def __init__(self, latent_dimension):
        """Define the hidden layer (bottleneck) and sampling layers"""

        super().__init__()
        self.hidden = tf.keras.layers.Dense(units=32)
        self.dense_mus = tf.keras.layers.Dense(units=latent_dimension)
        self.dense_log_vars = tf.keras.layers.Dense(units=latent_dimension)
        self.sampling = tf.keras.layers.Lambda(function=sampling)

    def call(self, inputs):
        """Define forward computation that outputs z, mu, log_var of input."""

        dense_projection = self.hidden(inputs)

        mus = self.dense_mus(dense_projection)
        log_vars = self.dense_log_vars(dense_projection)
        z = self.sampling([mus, log_vars])

        return z, mus, log_vars
        

class DummyDecoder(tf.keras.Model):
    def __init__(self, num_classes):
        """Define GRU layer and the Dense output layer"""

        super().__init__()
        self.gru = tf.keras.layers.GRU(units=1, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(units=num_classes, activation='softmax')

    def call(self, x, hidden_states=None):
        """Define forward computation"""

        outputs, h_t = self.gru(x, initial_state=hidden_states)

        # Project the GRU output (units=1) to num_classes logits and turn them
        # into normalized probabilities with the softmax activation of the
        # Dense layer
        reconstructions = self.dense(outputs)

        return reconstructions, h_t

# Instantiate the models
encoder_model = DummyEncoder(latent_dimension=latent_dims)
decoder_model = DummyDecoder(num_classes=num_classes)

# Instantiate reconstruction loss function
cce_loss_fxn = tf.keras.losses.CategoricalCrossentropy()

# Begin tape
with tf.GradientTape(persistent=True) as tape:
    # Flatten the inputs for the encoder
    reshaped_inputs = tf.reshape(original_inputs, shape=(tf.shape(original_inputs)[0], -1))

    # Encode the input
    z, mus, log_vars = encoder_model(reshaped_inputs, training=True)

    # Expand dimensions of z so it meets recurrent decoder requirements of
    # (batch, timesteps, features)
    z = tf.expand_dims(z, axis=1)

    ################################
    # SUSPECTED CAUSE OF PROBLEM
    ################################

    # A tensor that will be modified based on model outputs
    reconstruction_tensor = tf.Variable(tf.zeros_like(original_inputs))

    ################################
    # END SUSPECTED CAUSE OF PROBLEM
    ################################

    # A decoding loop that iteratively generates the next token (i.e., output)
    # in the sequence
    hidden_states = None
    for ith_token in range(max_seq_length):

        # Reconstruct the ith_token for a given sample in the batch
        reconstructions, hidden_states = decoder_model(z, hidden_states, training=True)

        # Drop the timestep axis, (batch, 1, num_classes) -> (batch, num_classes),
        # so the reconstructions can be assigned to reconstruction_tensor
        reconstructions = tf.squeeze(reconstructions, axis=1)

        # Each iteration assigns one token prediction per sample in the batch,
        # so once the loop finishes, this tensor is the model's prediction of
        # the original inputs.
        reconstruction_tensor = reconstruction_tensor[:, ith_token,:].assign(reconstructions)

    # Calculates losses
    recon_loss = cce_loss_fxn(original_inputs, reconstruction_tensor)
    latent_loss = latent_loss_fxn(mus, log_vars)
    loss = recon_loss + latent_loss

# Calculate gradients
encoder_gradients = tape.gradient(loss, encoder_model.trainable_weights)
decoder_gradients = tape.gradient(loss, decoder_model.trainable_weights)

# Release tape
del tape

# Inspect gradients
print('Valid Encoder Gradients:', all(g is not None for g in encoder_gradients))
print('Valid Decoder Gradients:', all(g is not None for g in decoder_gradients), ' -- ', decoder_gradients)

# Output:
# Valid Encoder Gradients: True
# Valid Decoder Gradients: False -- [None, None, None, None, None]
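The decoding loop removes the timestep axis from each decoder output. A quick check shows why passing `axis=1` to tf.squeeze is safer than calling it with no axis. Without an axis, tf.squeeze drops every size-1 dimension. With a batch of one, that silently removes the batch axis as well.

```python
import tensorflow as tf

# One decoding step's output: (batch, timesteps=1, num_classes)
step_batch_3 = tf.zeros((3, 1, 4))
step_batch_1 = tf.zeros((1, 1, 4))

print(tf.squeeze(step_batch_3).shape)          # (3, 4)
print(tf.squeeze(step_batch_1).shape)          # (4,)   batch axis dropped too
print(tf.squeeze(step_batch_1, axis=1).shape)  # (1, 4) only the time axis removed
```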

Solution

  • Found a 'solution' to my problem:

    There must be some problem with using a tf.Variable inside the GradientTape() context manager. I do not know what that problem is, but replacing reconstruction_tensor with a list, appending to that list on each decoding iteration, and then stacking the list lets the gradients be computed without a problem. The colab notebook reflects the changes. See the code snippet below for the fix:

    ....
    ....
    with tf.GradientTape(persistent=True) as tape:
        ....
        ....
    
        # FIX
        reconstructions_tensor = []
    
        hidden_states = None
        for ith_token in range(max_seq_length):
            # Reconstruct the ith_token for a given sample in the batch
            reconstructions, hidden_states = decoder_model(z, hidden_states, training=True)
    
            # Drop the timestep axis: (batch, 1, num_classes) -> (batch, num_classes)
            reconstructions = tf.squeeze(reconstructions, axis=1)
    
            # FIX
            # Appending to the list which will eventually be stacked
            reconstructions_tensor.append(reconstructions)
        
        # FIX
        # Stack the reconstructions along axis=1 to get the same result as the previous assignment into the zeros tensor
        reconstructions_tensor = tf.stack(reconstructions_tensor, axis=1)
    ....
    ....
    # Successful gradient computations and subsequent optimization of models
    # ....
    

    Edit 1:

    I don't think this 'solution' is ideal if the model is meant to run in graph mode (e.g., inside tf.function). My limited understanding is that graph mode does not handle Python objects such as lists well.
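On the graph-mode concern in Edit 1: as far as I know, inside tf.function a Python `for ... in range(max_seq_length)` loop is unrolled while tracing, so appending to a Python list and calling tf.stack still works there. A Python list becomes a problem when the loop itself is symbolic (e.g., over tf.range), because AutoGraph converts such a loop into a tf.while_loop. For that case, tf.TensorArray is the graph-friendly accumulator. The sketch below is a toy under stated assumptions. A hypothetical recurrence built from raw tf.Variables (`w_state`, `w_out`) stands in for the GRU and Dense layers, and `z` is random noise rather than encoder output.

```python
import tensorflow as tf

batch_size, max_seq_length, num_classes, latent_dims = 3, 3, 4, 5

# Hypothetical toy decoder weights (stand-ins for the GRU and Dense layers)
w_state = tf.Variable(tf.random.normal((latent_dims, latent_dims), stddev=0.1))
w_out = tf.Variable(tf.random.normal((latent_dims, num_classes), stddev=0.1))
decoder_vars = [w_state, w_out]

@tf.function
def decode_and_grad(z, targets):
    """Decode max_seq_length steps from z; return (loss, gradients w.r.t. decoder_vars)."""
    with tf.GradientTape() as tape:
        steps = tf.TensorArray(tf.float32, size=max_seq_length)
        state = z
        for i in tf.range(max_seq_length):                  # AutoGraph -> tf.while_loop
            state = tf.tanh(tf.matmul(state, w_state))      # toy recurrence
            probs = tf.nn.softmax(tf.matmul(state, w_out))  # (batch, num_classes)
            steps = steps.write(i, probs)
        # stack() returns (time, batch, classes); move time to axis 1
        recon = tf.transpose(steps.stack(), perm=[1, 0, 2])
        loss = -tf.reduce_mean(tf.reduce_sum(targets * tf.math.log(recon), axis=-1))
    return loss, tape.gradient(loss, decoder_vars)

targets = tf.one_hot([[0, 1, 2], [1, 0, 3], [3, 0, 1]], depth=num_classes)
z = tf.random.normal((batch_size, latent_dims))
loss, grads = decode_and_grad(z, targets)
print(float(loss), [g.shape for g in grads])
```

With the real models, the GRU and Dense calls would replace the toy recurrence, and the gradient targets would be encoder_model.trainable_weights + decoder_model.trainable_weights.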