python tensorflow deep-learning recurrent-neural-network

Accumulating output from a graph using tf.while_loop


I have an RNN that is stacked on top of a CNN. The CNN was created and trained separately. To clarify things, let's suppose the CNN takes input in the form of a [BATCH SIZE, H, W, C] placeholder (H = height, W = width, C = number of channels).

Now, when the RNN is stacked on top of the CNN, the overall input to the combined network has the shape [BATCH SIZE, TIME SEQUENCE, H, W, C], i.e. each sample in the minibatch consists of TIME SEQUENCE many images. Moreover, the time sequences are variable in length. A separate placeholder called sequence_lengths, with shape [BATCH SIZE], contains the length of each sample in the minibatch. TIME SEQUENCE is the maximum possible sequence length, and samples with smaller lengths are padded with zeros up to it.
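To make the padding convention concrete, here is a small sketch, assuming TensorFlow 2.x eager execution. The sizes, the all-ones "frames" and the name num_valid_frames are hypothetical; tf.sequence_mask turns sequence_lengths into a [BATCH SIZE, TIME SEQUENCE] boolean mask of the real (non-padded) time steps.

```python
import tensorflow as tf

# Hypothetical toy sizes, chosen only for illustration.
batch_size, max_sequence_length, H, W, C = 3, 4, 2, 2, 1

# Length of each sample; every sample is padded up to max_sequence_length.
sequence_lengths = tf.constant([4, 2, 1], dtype=tf.int32)

# True for real time steps, False for padded ones: shape [B, T].
mask = tf.sequence_mask(sequence_lengths, maxlen=max_sequence_length)

# Hypothetical padded input: real frames are ones, padded frames stay zero.
frames = tf.ones([batch_size, max_sequence_length, H, W, C])
input_image_sequence = frames * tf.cast(mask, tf.float32)[:, :, None, None, None]

# Number of real (non-padded) frames in the whole minibatch.
num_valid_frames = int(tf.reduce_sum(tf.cast(mask, tf.int32)))
```

Here only 7 of the 12 frames are real, which is the waste Option A runs into.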

What I want to do

I want to accumulate the output from the CNN in a tensor of shape [BATCH SIZE, TIME SEQUENCE, 1] (the last dimension just contains the final score output by the CNN for each time sample for each batch element) so that I can forward this entire chunk of information to the RNN that is stacked on top of the CNN. The tricky thing is, I also want to be able to back-propagate the error from the RNN to the CNN (the CNN is already pre-trained, but I would like to fine-tune the weights a bit), so I have to stay inside the graph, i.e. I can't make any calls to session.run().

  • Option A: The easiest way would be to just reshape the overall network input tensor to [BATCH SIZE * TIME SEQUENCE, H, W, C]. The problem with this is that BATCH SIZE * TIME SEQUENCE may be as large as 2000, so I'm bound to run out of memory when trying to feed a batch that big into my CNN. And the batch size is too large for training anyway. Also, a lot of sequences are just padded zeros, and it'd be a waste of computation.

  • Option B: Use tf.while_loop. My idea was to treat all the images along the time axis of a single minibatch element as a minibatch for the CNN. Essentially, the CNN would process batches of size [TIME SEQUENCE, H, W, C] at each iteration (not exactly TIME SEQUENCE many images every time; the exact number depends on the sequence length). The code I have right now looks like this:

     # The output tensor that I want populated
     image_output_sequence = tf.Variable(tf.zeros([batch_size, max_sequence_length, 1], tf.float32))
    
     # Counter for the loop. I'll process one batch element per iteration.
     # One batch element contains a variable number of images, one per time step. All these images will form a minibatch for the CNN.
     loop_counter = tf.get_variable('loop_counter', dtype=tf.int32, initializer=0)
    
     # Loop variables that will be passed to the body and cond methods
     loop_vars = [input_image_sequence, sequence_lengths, image_output_sequence, loop_counter]
     # input_image_sequence: [BATCH SIZE, TIME SEQUENCE, H, W, C]
     # sequence_lengths: [BATCH SIZE]
     # image_output_sequence: [BATCH SIZE, TIME SEQUENCE, 1]
    
     # abbreviations for vars in loop_vars:
     # iis --> input_image_sequence
     # sl --> sequence_lengths
     # ios --> image_output_sequence
     # lc --> loop_counter
     def cond(iis, sl, ios, lc):  
         return tf.less(lc, batch_size)
    
     def body(iis, sl, ios, lc):
         seq_len = sl[lc]  # the sequence length of the current batch element
         cnn_input_batch = iis[lc, :seq_len]  # extract the relevant portion (the rest are just padded zeros)
    
         # propagate this 'batch' through the CNN
         cnn_input_batch_features = my_cnn_model.process_input(cnn_input_batch)  # shape [seq_len, 1]
    
         # Pad the time axis back up to max_sequence_length with zeros
         padding = [[0, max_sequence_length - seq_len], [0, 0]]
         padded_cnn_output = tf.pad(cnn_input_batch_features, paddings=padding, mode='CONSTANT', constant_values=0)
    
         # The problematic part: assign these processed values to the output tensor
         ios[lc].assign(padded_cnn_output)
         return [iis, sl, ios, lc + 1]
    
     _, _, result, _ = tf.while_loop(cond, body, loop_vars, swap_memory=True)
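For comparison, Option A can be sketched like this, assuming TensorFlow 2.x eager execution. The tiny one-filter "CNN" (conv_filter, cnn_scores) is a hypothetical stand-in for the real model. The reshape only rearranges the shape; the problem is that the CNN then sees all B * T images in one call, padded frames included, so activation memory grows with B * T.

```python
import tensorflow as tf

# Hypothetical toy sizes (in the question B * T can reach about 2000).
B, T, H, W, C = 2, 3, 5, 5, 1

# Hypothetical stand-in for the pre-trained CNN: a single 3x3 conv filter
# followed by global average pooling, mapping [N, H, W, C] -> [N, 1].
conv_filter = tf.Variable(tf.random.normal([3, 3, C, 1]))

def cnn_scores(images):
    features = tf.nn.conv2d(images, conv_filter, strides=1, padding='SAME')
    return tf.reduce_mean(features, axis=[1, 2])

images = tf.random.normal([B, T, H, W, C])

# Option A: fold the time axis into the batch axis, run the CNN once on all
# B * T images, then unfold the scores back to [B, T, 1].
flat_images = tf.reshape(images, [B * T, H, W, C])
scores = tf.reshape(cnn_scores(flat_images), [B, T, 1])
```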
    

Inside my_cnn_model.process_input, I'm just passing the input through a vanilla CNN. All the variables in it are created with reuse=tf.AUTO_REUSE, which should ensure that the while loop reuses the same weights in every iteration.
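A minimal sketch of how AUTO_REUSE sharing behaves, assuming the TF1-style API via tf.compat.v1 in graph mode; score_layer and the 'cnn' scope are hypothetical. Calling the layer twice creates only one variable. One caveat worth checking: in TF 1.x graph mode, creating a variable for the first time inside a tf.while_loop body may fail with an error about an initializer being inside a control-flow construct, so building the CNN once before the loop (as here) is the safer order.

```python
import tensorflow as tf

tf1 = tf.compat.v1

def score_layer(x):
    """Hypothetical one-layer 'CNN head': [N, 3] features -> [N, 1] scores."""
    with tf1.variable_scope('cnn', reuse=tf1.AUTO_REUSE):
        # The first call creates 'cnn/w'; every later call returns that same variable.
        w = tf1.get_variable('w', shape=[3, 1], initializer=tf1.ones_initializer())
    return tf.matmul(x, w)

graph = tf.Graph()
with graph.as_default():
    out_a = score_layer(tf.zeros([2, 3]))
    out_b = score_layer(tf.zeros([5, 3]))
    variable_names = [v.name for v in tf1.global_variables()]
```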

The exact problem

image_output_sequence is a variable, but somehow when tf.while_loop calls the body method, it gets turned into a Tensor type object to which assignments can't be made. I get the error message: Sliced assignment is only supported for variables

The problem persists even with a different format, such as a tuple of BATCH SIZE tensors, each with dimensions [TIME SEQUENCE, H, W, C].
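The error comes from the fact that tf.while_loop passes its loop variables to body as plain Tensors, which are immutable, so body has to return new values instead of assigning into old ones. One alternative to the TensorArray answer below (not part of the original answer, and assuming TensorFlow 2.x) is tf.tensor_scatter_nd_update, which returns a copy of ios with one row replaced. The fill-with-(lc + 1) "CNN output" is hypothetical.

```python
import tensorflow as tf

batch_size, max_sequence_length = 3, 4
sequence_lengths = tf.constant([4, 2, 1], dtype=tf.int32)

def cond(ios, lc):
    return lc < batch_size

def body(ios, lc):
    seq_len = sequence_lengths[lc]
    # Hypothetical CNN output for this batch element: seq_len scores equal to lc + 1.
    scores = tf.fill(tf.stack([seq_len, 1]), tf.cast(lc + 1, tf.float32))
    # Zero-pad the time axis up to max_sequence_length -> [max_sequence_length, 1].
    padded = tf.pad(scores, [[0, max_sequence_length - seq_len], [0, 0]])
    # Functional replacement for ios[lc].assign(...): a NEW tensor equal to ios
    # except that row lc is replaced by `padded`.
    ios = tf.tensor_scatter_nd_update(ios, tf.reshape(lc, [1, 1]), padded[tf.newaxis])
    return ios, lc + 1

ios0 = tf.zeros([batch_size, max_sequence_length, 1])
result, _ = tf.while_loop(cond, body, [ios0, tf.constant(0)])
```

Each iteration produces a fresh [BATCH SIZE, TIME SEQUENCE, 1] tensor, which is cheap here because the output is small.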

I'm open to a complete redesign of the code as well, as long as it gets the job done nicely.


Solution

  • The solution is to use an object of type TensorArray, which is specifically made to address such problems. The following line:

    image_output_sequence = tf.Variable(tf.zeros([batch_size, max_sequence_length, 1], tf.float32))
    

    is replaced by:

    image_output_sequence = tf.TensorArray(size=batch_size, dtype=tf.float32, element_shape=[max_sequence_length, 1], infer_shape=True)
    

    TensorArray doesn't actually require a fixed shape for each element, but in my case the shape is fixed, so it's better to enforce it.

    Then inside the body function, replace this:

    ios[lc].assign(padded_cnn_output)
    

    with:

    ios = ios.write(lc, padded_cnn_output)
    

    Then after the tf.while_loop statement, the TensorArray can be stacked to form a regular Tensor for further processing:

    stacked_tensor = result.stack()
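Putting the three steps together, here is a minimal end-to-end sketch, assuming TensorFlow 2.x: tf.function stands in for the TF1 graph and tf.GradientTape stands in for the session/optimizer. The one-filter CNN (conv_filter, cnn_scores) and the sum loss are hypothetical stand-ins for my_cnn_model and the RNN. Only the seq_len real frames of each element go through the CNN, and the gradient check reflects the fine-tuning requirement: the error flows back through tf.while_loop and the TensorArray into the CNN weight.

```python
import tensorflow as tf

# Hypothetical toy sizes and inputs.
batch_size, max_sequence_length, H, W, C = 3, 4, 5, 5, 1
sequence_lengths = tf.constant([4, 2, 1], dtype=tf.int32)
input_image_sequence = tf.random.normal([batch_size, max_sequence_length, H, W, C])

# Hypothetical stand-in for my_cnn_model: one 3x3 conv filter plus global average
# pooling, [N, H, W, C] -> [N, 1]. The weight is created once, outside the loop,
# so every iteration uses (and back-propagates into) the same variable.
conv_filter = tf.Variable(tf.random.normal([3, 3, C, 1]))

def cnn_scores(images):
    features = tf.nn.conv2d(images, conv_filter, strides=1, padding='SAME')
    return tf.reduce_mean(features, axis=[1, 2])

@tf.function  # traced into a graph, so tf.while_loop builds a real graph loop
def cnn_over_sequences(iis, sl):
    ios = tf.TensorArray(dtype=tf.float32, size=batch_size,
                         element_shape=[max_sequence_length, 1])

    def cond(lc, ios):
        return lc < batch_size

    def body(lc, ios):
        seq_len = sl[lc]
        scores = cnn_scores(iis[lc, :seq_len])  # only the real frames -> [seq_len, 1]
        padded = tf.pad(scores, [[0, max_sequence_length - seq_len], [0, 0]])
        return (lc + 1, ios.write(lc, padded))  # write() returns the updated TensorArray

    _, ios = tf.while_loop(cond, body, (tf.constant(0), ios))
    return ios.stack()  # [batch_size, max_sequence_length, 1]

with tf.GradientTape() as tape:
    image_output_sequence = cnn_over_sequences(input_image_sequence, sequence_lengths)
    loss = tf.reduce_sum(image_output_sequence)  # hypothetical stand-in for the RNN loss
conv_grad = tape.gradient(loss, conv_filter)
```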