Tags: python, tensorflow, machine-learning, keras, recurrent-neural-network

For loop with GRUCell in call method of subclassed tf.keras.Model


I have subclassed tf.keras.Model and I use tf.keras.layers.GRUCell in a for loop to compute sequences 'y_t' (n, timesteps, hidden_units) and final hidden states 'h_t' (n, hidden_units). For my loop to output 'y_t', I update a tf.Variable after each iteration of the loop. Calling the model directly with model(input) works fine, but when I fit the model (with the for loop in the call method) I get either a TypeError or a ValueError, depending on the execution mode.

Please note, I cannot simply use tf.keras.layers.GRU because I am trying to implement this paper. Instead of just passing x_t to the next cell in the RNN, the paper performs some computation as a step in the for loop (their implementation is in PyTorch) and passes the result of that computation to the RNN cell. They end up essentially doing this: h_t = f(special_x_t, h_{t-1}).
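
To illustrate the per-step pattern (this is not the paper's code; the Dense transform below is only a hypothetical stand-in for that per-step computation), a single timestep looks roughly like:

import tensorflow as tf

# Arbitrary dimensions for illustration
batch, features, hidden_units = 4, 10, 5

gru_cell = tf.keras.layers.GRUCell(units=hidden_units)
special_transform = tf.keras.layers.Dense(hidden_units)  # hypothetical stand-in for the paper's computation

x_t = tf.random.uniform(shape=(batch, features))     # input at timestep t
h_t_minus_1 = tf.zeros(shape=(batch, hidden_units))  # previous hidden state h_{t-1}

special_x_t = special_transform(x_t)                 # the 'special' per-step computation
y_t, h_t = gru_cell(special_x_t, h_t_minus_1)        # h_t = f(special_x_t, h_{t-1})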

Please see the model below that causes the error:

class CustomGruRNN(tf.keras.Model):
    def __init__(self, batch_size, timesteps, hidden_units, features, **kwargs):

        # Inheritance
        super().__init__(**kwargs)

        # Args
        self.batch_size = batch_size
        self.timesteps = timesteps
        self.hidden_units = hidden_units        

        # Stores y_t
        self.rnn_outputs = tf.Variable(tf.zeros(shape=(batch_size, timesteps, hidden_units)), trainable=False)

        # To be used in for loop in call
        self.gru_cell = tf.keras.layers.GRUCell(units=hidden_units)

        # Reshape to match input dimensions
        self.dense = tf.keras.layers.Dense(units=features)

    def call(self, inputs):
        """Inputs is rank-3 tensor of shape (n, timesteps, features) """

        # Initial state for gru cell
        h_t = tf.zeros(shape=(self.batch_size, self.hidden_units))

        for timestep in tf.range(self.timesteps):
            # Get the t-th timestep of the inputs
            x_t = tf.gather(inputs, timestep, axis=1)  # Same as x_t = inputs[:, timestep, :]

            # Compute outputs and hidden states
            y_t, h_t = self.gru_cell(x_t, h_t)
            
            # Update y_t at the t^th timestep
            self.rnn_outputs = self.rnn_outputs[:, timestep, :].assign(y_t)

        # Outputs need to have same last dimension as inputs
        outputs = self.dense(self.rnn_outputs)

        return outputs

An example that would throw the error:

# Arbitrary values for dataset
num_samples = 128
batch_size = 4
timesteps = 5
features = 10

# Arbitrary dataset
x = tf.random.uniform(shape=(num_samples, timesteps, features))
y = tf.random.uniform(shape=(num_samples, timesteps, features))

train_data = tf.data.Dataset.from_tensor_slices((x, y))
train_data = train_data.shuffle(batch_size).batch(batch_size, drop_remainder=True)

# Model with arbitrary hidden units
model = CustomGruRNN(batch_size, timesteps, hidden_units=5, features=features)
model.compile(loss=tf.keras.losses.MeanSquaredError(), optimizer=tf.keras.optimizers.Adam())

When running eagerly:

model.fit(train_data, epochs=2, run_eagerly=True)

Epoch 1/2
WARNING:tensorflow:Gradients do not exist for variables ['stack_overflow_gru_rnn/gru_cell/kernel:0', 'stack_overflow_gru_rnn/gru_cell/recurrent_kernel:0', 'stack_overflow_gru_rnn/gru_cell/bias:0'] when minimizing the loss.
ValueError: substring not found

When not running eagerly:

model.fit(train_data, epochs=2, run_eagerly=False)

Epoch 1/2
TypeError: in user code:
    TypeError: Can not convert a NoneType into a Tensor or Operation.


Solution

  • Edit:

    While the TensorFlow guide answer suffices, I think my self-answered question involving custom cells for RNNs is a much better option. Please see this answer. Using a custom RNN cell removes the need for tf.transpose and tf.TensorArray, which lowers the complexity of the code while improving readability; a minimal sketch of the idea is shown below.
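
    For reference, here is a hedged sketch of the custom-cell approach (this is not the linked answer's exact code, and the Dense transform is only a placeholder for the paper's per-step computation):

    import tensorflow as tf

    class SpecialGRUCell(tf.keras.layers.Layer):
        """Wraps a GRUCell and applies a custom transform to each timestep's input."""

        def __init__(self, hidden_units, **kwargs):
            super().__init__(**kwargs)
            # state_size/output_size are required by tf.keras.layers.RNN
            self.state_size = hidden_units
            self.output_size = hidden_units
            self.transform = tf.keras.layers.Dense(hidden_units)  # placeholder for the custom step
            self.gru_cell = tf.keras.layers.GRUCell(units=hidden_units)

        def call(self, inputs, states):
            # Custom per-timestep computation before the GRU update
            special_x_t = self.transform(inputs)
            # h_t = f(special_x_t, h_{t-1})
            y_t, new_states = self.gru_cell(special_x_t, states)
            return y_t, new_states

    # tf.keras.layers.RNN runs the time loop over the custom cell (graph-friendly)
    rnn = tf.keras.layers.RNN(SpecialGRUCell(hidden_units=5), return_sequences=True)
    sequences = rnn(tf.random.uniform(shape=(4, 5, 10)))  # -> (batch=4, timesteps=5, hidden_units=5)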

    Original Self-Answer:

    The use of the DynamicRNN described near the bottom of TensorFlow's guide Effective TensorFlow 2 solves my problem.

    To expand briefly on the DynamicRNN's conceptual use: an RNN cell is defined (in my case a GRU cell), and any number of custom steps can then be defined within the tf.range loop. Per-timestep outputs should be tracked with tf.TensorArray objects created outside the loop but inside the call method itself, and the sizes of such arrays can be read from the .shape attribute of the input tensors. Notably, the DynamicRNN pattern works in Model.fit, whose default execution mode is graph mode, as opposed to the slower eager execution mode; the pattern is sketched below.
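
    Here is a hedged sketch of that pattern, adapted from the guide's DynamicRNN example to this question's shapes (the class and layer names are illustrative, not the guide's exact code):

    import tensorflow as tf

    class DynamicGruRNN(tf.keras.Model):
        def __init__(self, hidden_units, features, **kwargs):
            super().__init__(**kwargs)
            self.hidden_units = hidden_units
            self.gru_cell = tf.keras.layers.GRUCell(units=hidden_units)
            self.dense = tf.keras.layers.Dense(units=features)

        def call(self, inputs):
            """inputs is a rank-3 tensor of shape (batch, timesteps, features)."""
            # Time-major layout makes per-timestep indexing cheap inside the loop
            inputs = tf.transpose(inputs, [1, 0, 2])
            timesteps = tf.shape(inputs)[0]
            batch_size = tf.shape(inputs)[1]

            # TensorArray tracks per-timestep outputs; its size comes from the input shape
            outputs = tf.TensorArray(tf.float32, size=timesteps)
            state = [tf.zeros(shape=(batch_size, self.hidden_units))]

            for t in tf.range(timesteps):
                # Any custom per-step computation can be inserted here before the cell call
                y_t, state = self.gru_cell(inputs[t], state)
                outputs = outputs.write(t, y_t)

            # Back to batch-major (batch, timesteps, hidden_units), then project to feature size
            sequences = tf.transpose(outputs.stack(), [1, 0, 2])
            return self.dense(sequences)

    Compiled and fit as in the question, this version should work in Model.fit's default graph mode, since the per-timestep outputs are tracked in a TensorArray rather than in an assigned tf.Variable.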

    Lastly, one might require the use of a DynamicRNN because, by default, the tf.keras.layers.GRU computation is loosely described by the following recurrent logic (assume that 'f' defines a GRU cell):

    import numpy as np
    import tensorflow as tf

    # Arbitrary dimensions; 'f' is a Keras GRUCell standing in for the cell function
    batch, total_timesteps, features, hidden_cell_units = 4, 5, 10, 8
    f = tf.keras.layers.GRUCell(units=hidden_cell_units)

    # NumPy is used here for ease of indexing, but in general you should use
    # tensors and transpose them accordingly (see the previously linked guide)
    inputs = np.random.randn(batch, total_timesteps, features).astype(np.float32)

    # List for tracking outputs -- just for simple demonstration; again, please see the guide for more details
    outputs = []

    # Initialize the 'hidden state' (often referred to as h_naught and denoted h_0) of the RNN cell
    state_at_t_minus_1 = tf.zeros(shape=(batch, hidden_cell_units))

    # Iterate through the input until all timesteps in the sequence have been 'seen' by the GRU cell function 'f'
    for timestep_t in range(total_timesteps):
        # This slice is of shape (batch, features)
        input_at_t = inputs[:, timestep_t, :]

        # output_at_t has shape (batch, hidden_cell_units); state_at_t holds the updated hidden state
        output_at_t, state_at_t = f(input_at_t, state_at_t_minus_1)
        outputs.append(output_at_t)

        # When the loop restarts, this variable will be used in the next GRU cell function call 'f'
        state_at_t_minus_1 = state_at_t

    One might wish to add other steps inside the for loop of this recurrent logic (e.g., dense layers or other custom computations) to modify the inputs and states passed to the GRU cell function 'f'; this is one motivation for the DynamicRNN. A sketch of such a modified loop follows.
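
    For concreteness, here is a hedged sketch of such a modified loop, continuing from the snippet above (the Dense layer named 'modify' is hypothetical and not taken from the paper):

    # Hypothetical extra step applied to the raw input before the cell call
    modify = tf.keras.layers.Dense(features)

    state_at_t_minus_1 = tf.zeros(shape=(batch, hidden_cell_units))
    outputs = []

    for timestep_t in range(total_timesteps):
        input_at_t = inputs[:, timestep_t, :]

        # Custom per-step computation on the raw input
        special_input_at_t = modify(input_at_t)

        # The cell now sees the transformed input: h_t = f(special_x_t, h_{t-1})
        output_at_t, state_at_t = f(special_input_at_t, state_at_t_minus_1)
        outputs.append(output_at_t)
        state_at_t_minus_1 = state_at_t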