Tags: tensorflow, keras, deep-learning, concatenation, weighted-average

How can I add trainable weights while concatenating layers?


I am trying to concatenate two layers so that each layer is assigned a trainable weight during concatenation. The idea is that the model can learn which layer should be given the higher weight.

I have read this code: https://stackoverflow.com/a/62595957/12848819

    import tensorflow as tf
    from tensorflow.keras.layers import Concatenate, Layer

    class WeightedAverage(Layer):

        def __init__(self, n_output):
            super(WeightedAverage, self).__init__()
            # one trainable weight per input; n_output is the number of inputs
            self.W = tf.Variable(
                initial_value=tf.random.uniform(shape=[1, 1, n_output], minval=0, maxval=1),
                trainable=True)  # (1, 1, n_inputs)

        def call(self, inputs):
            # inputs is a list of tensors of shape [(n_batch, n_feat), ..., (n_batch, n_feat)]
            # expand the last dim of each input: [(n_batch, n_feat, 1), ..., (n_batch, n_feat, 1)]
            inputs = [tf.expand_dims(i, -1) for i in inputs]
            inputs = Concatenate(axis=-1)(inputs)  # (n_batch, n_feat, n_inputs)
            weights = tf.nn.softmax(self.W, axis=-1)  # (1, 1, n_inputs)
            # weights sum to one along the last dim

            return tf.reduce_sum(weights * inputs, axis=-1)  # (n_batch, n_feat)

However, this code computes a weighted average of the layers rather than a weighted concatenation. Please help, and let me know if you have more questions. Thanks.


Solution

  • I have used a weighted sum (not an average) to similar effect:

    import tensorflow as tf
    from tensorflow.keras import layers

    class WeightedSum(layers.Layer):
        """A custom Keras layer that learns a weighted sum of two tensors."""

        def __init__(self, **kwargs):
            super(WeightedSum, self).__init__(**kwargs)

        def build(self, input_shape):
            # single trainable mixing coefficient, initialised to 0.5;
            # the constraint keeps the norm of alpha within [0, 1]
            self.a = self.add_weight(name='alpha',
                                     shape=(1,),
                                     initializer=tf.keras.initializers.Constant(0.5),
                                     dtype='float32',
                                     trainable=True,
                                     constraint=tf.keras.constraints.MinMaxNorm(
                                         max_value=1, min_value=0))
            super(WeightedSum, self).build(input_shape)

        def call(self, model_outputs):
            # model_outputs is a list of two tensors with the same shape
            return self.a * model_outputs[0] + (1 - self.a) * model_outputs[1]

        def compute_output_shape(self, input_shape):
            # the output has the shape of a single input
            return input_shape[0]
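
If the goal is to keep the concatenation rather than sum the branches, one possible approach is to scale each branch by its own trainable scalar before concatenating. The following is only a minimal sketch under that assumption: the layer name `WeightedConcatenate` and the weight name `branch_weights` are made up for illustration and are not from the linked answer, and the softmax normalisation is borrowed from the `WeightedAverage` layer in the question.

```python
import tensorflow as tf
from tensorflow.keras import layers


class WeightedConcatenate(layers.Layer):
    """Hypothetical layer: scale each input by a trainable scalar, then concatenate."""

    def build(self, input_shape):
        # input_shape is a list holding one shape per input tensor
        n_inputs = len(input_shape)
        self.w = self.add_weight(
            name="branch_weights",  # illustrative name
            shape=(n_inputs,),
            initializer="zeros",  # softmax(zeros) gives equal weights at the start
            trainable=True,
        )

    def call(self, inputs):
        # softmax keeps the weights positive and makes them sum to one
        weights = tf.nn.softmax(self.w)
        scaled = [weights[i] * x for i, x in enumerate(inputs)]
        # inputs may differ in their last dimension; all other dims must match
        return tf.concat(scaled, axis=-1)


# Usage with two branches of different widths
a = tf.ones((4, 3))
b = tf.ones((4, 5))
layer = WeightedConcatenate()
out = layer([a, b])
print(out.shape)  # (4, 8)
```

Unlike the weighted sum, this keeps every feature of both branches, so the inputs do not need the same number of features. The softmax is a design choice: dropping it leaves the scalars unconstrained, which may or may not suit a given model.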