python tensorflow keras keras-layer

Is there a way to zero out weak weights during training? For example, if the absolute value of a weight is lower than 0.05, just set that weight to 0.


I was having trouble because it seems you can't edit a tensor directly, or simply convert it to NumPy and edit it in that form, during training. To some extent, what I'm looking for is the opposite of the clip function that exists in both TensorFlow and NumPy. Instead of making sure all values are between min and max, I want to set all values between min and max to 0. Most likely min and max will have the same magnitude with opposite signs, so this becomes zeroing any weight whose absolute value is less than some input value.
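To make the intended behaviour concrete, here is a minimal NumPy sketch of that "opposite of clip" operation on a plain array, outside of any training loop. The helper name `zero_weak` is made up for illustration and is not part of NumPy or TensorFlow.

```python
import numpy as np


def zero_weak(values, threshold):
    """Return a copy of `values` where entries with |x| < threshold become 0."""
    values = np.asarray(values, dtype=float)
    # np.where picks 0.0 where the condition holds and the original value elsewhere.
    return np.where(np.abs(values) < threshold, 0.0, values)


weights = np.array([0.01, -0.2, 0.04, 0.5, -0.03])
print(zero_weak(weights, 0.05))
```

The open question is how to apply the same rule to a `tf.Variable` while the model is training.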

import tensorflow as tf
from tensorflow import keras


class DeleteWeakConnectionsDenseLayer(keras.layers.Layer):
    def __init__(self, units, weak_threshold, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.weak_threshold = weak_threshold

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs, training=False):
        if training:
            # Missing piece: new_weights should be self.w with every weight whose
            # absolute value is below self.weak_threshold reassigned to 0.
            new_weights = ...
            self.w.assign(new_weights)  # assign() keeps self.w a tf.Variable
        else:
            pass  # could think about multiplying all weights by a constant here
        return tf.nn.relu(tf.matmul(inputs, self.w) + self.b)

Solution

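  • Try this code. It masks out the weak weights by multiplying the kernel by a 0/1 mask, then assigns the result back: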

      def call(self, inputs, training=False):
          if training:
              # True where the weight is strong enough to keep (|w| >= threshold)
              mask = tf.abs(self.w) >= self.weak_threshold
              # Cast to the kernel's own dtype so this also works outside float32
              new_weights = self.w * tf.cast(mask, self.w.dtype)
              self.w.assign(new_weights)  # assign() keeps self.w a tf.Variable
          else:
              pass  # could think about multiplying all weights by a constant here
          return tf.nn.relu(tf.matmul(inputs, self.w) + self.b)
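One thing to keep in mind with masking inside `call`: the optimizer still updates the zeroed weights after the forward pass, so they may drift away from 0 again before the next step. A possible alternative, which is not part of the answer above, is a custom Keras weight constraint, since Keras applies constraints to a variable after each gradient update. The sketch below assumes TensorFlow 2.x; the class name `ZeroWeakWeights` is hypothetical.

```python
import tensorflow as tf


class ZeroWeakWeights(tf.keras.constraints.Constraint):
    """Hypothetical constraint: zero every weight with |w| < weak_threshold."""

    def __init__(self, weak_threshold):
        self.weak_threshold = weak_threshold

    def __call__(self, w):
        w = tf.convert_to_tensor(w)
        # 1 where the weight is kept, 0 where it is considered weak.
        keep = tf.cast(tf.abs(w) >= self.weak_threshold, w.dtype)
        return w * keep

    def get_config(self):
        # Lets Keras serialize the constraint together with the layer.
        return {"weak_threshold": self.weak_threshold}
```

It could then be attached in `build` with `self.add_weight(..., constraint=ZeroWeakWeights(self.weak_threshold))`, or on a stock layer with `keras.layers.Dense(units, kernel_constraint=ZeroWeakWeights(0.05))`. Whether this trains as well as the in-`call` masking is something to check on your own model.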