
How to have un-tracked weights in custom keras layer?


I would like to create a custom Keras layer (a codebook for a VQ-VAE model). While training, I would like to have a tf.Variable that tracks the usage of each code so I can restart unused codes. So I created my Codebook layer as follows:

class Codebook(layers.Layer):
    def __init__(self, num_codes, code_reset_limit=None, **kwargs):
        super().__init__(**kwargs)
        self.num_codes = num_codes
        self.code_reset_limit = code_reset_limit
        if self.code_reset_limit:
            self.code_counter = tf.Variable(tf.zeros(num_codes, dtype=tf.int32), trainable=False)

    def build(self, input_shape):
        self.codes = self.add_weight(name='codes',
                                     shape=(self.num_codes, input_shape[-1]),
                                     initializer='random_uniform',
                                     trainable=True)
        super().build(input_shape)

The issue I have is that the Layer class finds the member variable self.code_counter and adds it to the list of weights that are saved with the layer. It also expects self.code_counter to be present when weights are loaded, which is not the case when I run in inference mode. How can I make it so Keras does not track a variable in my layer? I do not want it persisted or to be part of layer.weights.


Solution

  • I am a bit late with the answer, but I had the same problem and came across this question while it was still unanswered. I have since found a solution that works for both Keras 2 and Keras 3, so I am sharing it here for others encountering the same issue.

    To prevent TensorFlow and Keras from tracking a variable, one needs to encapsulate it in a class that the TensorFlow/Keras tracking machinery does not handle. The list of classes that are automatically tracked in Keras 3 is: keras.Variable, list, dict, tuple, and NamedTuple (see here). For Keras 2 the list of tracked types is harder to find in the documentation, but it appears to include tf.Variable (see the present question), dict, and list.
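To see why wrapping helps, here is a pure-Python sketch of the mechanism (all names hypothetical, no TensorFlow dependency): a base class intercepts __setattr__ and records any attribute whose value is one of the tracked types, while a plain wrapper object slips through untracked.

```python
from dataclasses import dataclass

class Variable:          # stand-in for tf.Variable / keras.Variable
    def __init__(self, value):
        self.value = value

@dataclass
class DoNotTrackContainer:
    data: Variable       # the wrapped variable is invisible to the tracker

class TrackingBase:      # stand-in for the Layer tracking machinery
    TRACKED_TYPES = (Variable, list, dict, tuple)

    def __init__(self):
        # bypass our own __setattr__ so the bookkeeping list is not tracked
        object.__setattr__(self, "_tracked", [])

    def __setattr__(self, name, value):
        # record any attribute whose type the tracker knows about
        if isinstance(value, self.TRACKED_TYPES):
            self._tracked.append(name)
        object.__setattr__(self, name, value)

class MyLayer(TrackingBase):
    def __init__(self):
        super().__init__()
        self.codes = Variable([0.0])                             # tracked
        self.code_counter = DoNotTrackContainer(Variable([0]))   # not tracked

layer = MyLayer()
# layer._tracked == ["codes"]; the wrapped counter never appears in it
```

The real Keras tracking logic is more involved, but the principle is the same: only values whose type the tracker recognizes are registered, so an opaque container hides the variable.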

    The solution that worked in my context for both keras.Variable and tf.Variable is to create a dataclass encapsulating the Variable. Here is the setup for TensorFlow and Keras 2.

    import tensorflow as tf
    from dataclasses import dataclass
    
    @dataclass
    class DoNotTrackContainer:
        data: tf.Variable
    
    

    In the code from the question, this would then be used like this:

     if self.code_reset_limit:
         self.code_counter = DoNotTrackContainer(data=tf.Variable(tf.zeros(num_codes, dtype=tf.int32), trainable=False))
    

    When accessing the counter, the data attribute needs to be included in the path:

      # for accessing the counter
      self.code_counter.data.assign_add(1) 
    

    For Keras 3 the container becomes:

    import keras
    from dataclasses import dataclass
    
    @dataclass
    class DoNotTrackContainer:
        data: keras.Variable
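For completeness, here is a framework-free sketch (all names hypothetical) of the code-reset pass this counter enables in a VQ-VAE: any codebook entry used fewer than reset_limit times since the last check is re-initialized, and all counters are cleared for the next counting window.

```python
import random

def reset_unused_codes(codes, usage_counts, reset_limit, rng=None):
    """Re-initialize rarely used codebook entries.

    codes: list of code vectors (lists of floats), mutated in place
    usage_counts: per-code usage counters, mutated in place
    reset_limit: minimum usage count for a code to survive
    Returns the number of codes that were reset.
    """
    rng = rng or random.Random(0)
    n_reset = 0
    for i, count in enumerate(usage_counts):
        if count < reset_limit:
            # re-draw the code from the same small uniform range used at init
            codes[i] = [rng.uniform(-0.05, 0.05) for _ in codes[i]]
            n_reset += 1
        usage_counts[i] = 0   # start a fresh counting window
    return n_reset
```

In the layer itself, the equivalent step would read the counts from self.code_counter.data and write the new vectors back into self.codes with assign(); the sketch above only illustrates the bookkeeping, not the TensorFlow API calls.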