Tags: keras, keras-layer, tf.keras

Keras multiply parallel layers' outputs with constrained weights


I have 3 parallel MLPs and want to obtain the following in Keras:

Out = W1 * Out_MLP1 + W2 * Out_MLP2 + W3 * Out_MLP3

where Out_MLP1, Out_MLP2 and Out_MLP3 are the output layers of the three MLPs, each with shape (10,), and W1, W2 and W3 are three trainable scalar weights (floats) that satisfy the following condition:

W1 + W2 + W3 = 1

What is the best way to implement this with the Keras functional API? And what if we had N parallel layers?


Solution

  • What you need is to apply a softmax to a set of learnable weights, so that they are guaranteed to sum to 1.

    We initialize the learnable weights in a custom layer. This layer receives the outputs of our MLPs and combines them according to our logic W1 * Out_MLP1 + W2 * Out_MLP2 + W3 * Out_MLP3. The output is a tensor of shape (10,).

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Layer, Input, Dense, Concatenate
    from tensorflow.keras.models import Model
    
    class W_ADD(Layer):
    
        def __init__(self, n_output):
            super(W_ADD, self).__init__()
            # one learnable scalar per parallel branch, shape (1, 1, n_output)
            self.W = tf.Variable(
                initial_value=tf.random.uniform(shape=[1, 1, n_output], minval=0, maxval=1),
                trainable=True)
    
        def call(self, inputs):
            # inputs is a list of tensors of shape [(n_batch, n_feat), ..., (n_batch, n_feat)]
            # expand the last dim of each input: [(n_batch, n_feat, 1), ..., (n_batch, n_feat, 1)]
            inputs = [tf.expand_dims(i, -1) for i in inputs]
            inputs = Concatenate(axis=-1)(inputs) # (n_batch, n_feat, n_output)
            # the softmax guarantees that the weights sum to 1 on the last dim
            weights = tf.nn.softmax(self.W, axis=-1) # (1, 1, n_output)
            return tf.reduce_sum(weights*inputs, axis=-1) # (n_batch, n_feat)
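
    As a quick sanity check (a minimal sketch, not part of the original answer), you can call the layer on dummy tensors and confirm that the softmax-normalized weights sum to 1 and that the output has the expected shape:

    # sanity check for W_ADD (assumes the imports and class definition above)
    layer = W_ADD(n_output=3)
    dummy = [tf.random.uniform((8, 10)) for _ in range(3)]  # 3 branches, batch of 8, 10 features
    out = layer(dummy)
    print(out.shape)                                       # (8, 10)
    print(tf.reduce_sum(tf.nn.softmax(layer.W, axis=-1)))  # ~1.0, i.e. W1 + W2 + W3 = 1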
    

    In this dummy example, I create a network that has 3 parallel MLPs:

    inp1 = Input((100,))
    inp2 = Input((100,))
    inp3 = Input((100,))
    x1 = Dense(32, activation='relu')(inp1)
    x2 = Dense(32, activation='relu')(inp2)
    x3 = Dense(32, activation='relu')(inp3)
    x1 = Dense(10, activation='linear')(x1)
    x2 = Dense(10, activation='linear')(x2)
    x3 = Dense(10, activation='linear')(x3)
    mlp_outputs = [x1,x2,x3]
    out = W_ADD(n_output=len(mlp_outputs))(mlp_outputs)
    
    m = Model([inp1,inp2,inp3], out)
    m.compile('adam','mse')
    
    X1 = np.random.uniform(0,1, (1000,100))
    X2 = np.random.uniform(0,1, (1000,100))
    X3 = np.random.uniform(0,1, (1000,100))
    y = np.random.uniform(0,1, (1000,10))
    
    m.fit([X1,X2,X3], y, epochs=10)
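
    After training, you can inspect the combination weights the model actually learned. Below is a minimal sketch (assuming the model m defined above) that prints the softmax-normalized weights, which sum to 1 by construction:

    # retrieve the learned combination weights (W1, W2, W3)
    w_add = [l for l in m.layers if isinstance(l, W_ADD)][0]
    learned = tf.nn.softmax(w_add.W, axis=-1).numpy().ravel()
    print(learned, learned.sum())  # three floats summing to ~1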
    

    As you can see, this generalizes easily to N parallel layers.
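
    For instance, here is a minimal sketch of the N-layer case (the names N, inps, outs and m_n are mine, not from the original answer):

    N = 5  # any number of parallel branches
    inps = [Input((100,)) for _ in range(N)]
    outs = [Dense(10, activation='linear')(Dense(32, activation='relu')(inp)) for inp in inps]
    out = W_ADD(n_output=N)(outs)
    m_n = Model(inps, out)
    m_n.compile('adam', 'mse')
    Xs = [np.random.uniform(0, 1, (1000, 100)) for _ in range(N)]
    y = np.random.uniform(0, 1, (1000, 10))
    m_n.fit(Xs, y, epochs=3)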