Tags: python, tensorflow, keras, neural-network, tensorflow2.0

Efficiently use Dense layers in parallel


I need to implement a layer in TensorFlow for a dataset of size N, where each sample has a set of M independent features (each feature is a tensor of length L). I want to train M dense layers in parallel, one per feature, and then concatenate the resulting output tensors.

I could implement such a layer with a for loop, as below:

import tensorflow as tf

class MyParallelDenseLayer(tf.keras.layers.Layer):

    def __init__(self, dense_kwargs, **kwargs):
        super().__init__(**kwargs)
        self.dense_kwargs = dense_kwargs

    def build(self, input_shape):
        # input_shape is (batch, M, L); the batch dimension is None at build time.
        self.N, self.M, self.L = input_shape
        # One independent Dense layer per feature m.
        self.list_dense_layers = [tf.keras.layers.Dense(**self.dense_kwargs) for _ in range(self.M)]
        super().build(input_shape)

    def call(self, inputs):
        # Apply each Dense layer to its feature slice one at a time, then concatenate.
        parallel_output = [self.list_dense_layers[i](inputs[:, i]) for i in range(self.M)]
        return tf.keras.layers.Concatenate()(parallel_output)

But the for loop in the 'call' method makes my layer extremely slow. Is there a faster way to implement this layer?


Solution

  • This should be doable using einsum. Expand this layer to your liking with activation functions and whatnot (a sketch with a bias term and an activation follows the test below).

    class ParallelDense(tf.keras.layers.Layer):
        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.units = units
    
        def build(self, input_shape):
            super().build(input_shape)
            # One (L, units) weight matrix per parallel feature m: kernel shape (M, L, units).
            self.kernel = self.add_weight(shape=[input_shape[1], input_shape[2], self.units])

        def call(self, inputs):
            # Batched matmul over the m axis: (b, m, l) x (m, l, k) -> (b, m, k).
            return tf.einsum("bml, mlk -> bmk", inputs, self.kernel)
    

    Test it:

    b = 16  # batch size
    m = 200
    l = 4  # no. of input features per m
    k = 10  # no. of output features per m
    
    layer = ParallelDense(k)
    inp = tf.random.normal([b, m, l])
    
    print(layer(inp).shape)
    

    (16, 200, 10)
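
    As suggested above, the layer can be extended with a bias term and an activation. Below is a minimal sketch of one way to do that, assuming the same (batch, M, L) input layout; the class name ParallelDenseWithBias and its arguments are illustrative additions, not part of the original answer.

    import tensorflow as tf

    class ParallelDenseWithBias(tf.keras.layers.Layer):
        def __init__(self, units, activation=None, **kwargs):
            super().__init__(**kwargs)
            self.units = units
            self.activation = tf.keras.activations.get(activation)

        def build(self, input_shape):
            super().build(input_shape)
            # Per-m kernel of shape (M, L, units) and per-m bias of shape (M, units).
            self.kernel = self.add_weight(shape=[input_shape[1], input_shape[2], self.units])
            self.bias = self.add_weight(shape=[input_shape[1], self.units], initializer="zeros")

        def call(self, inputs):
            # Same batched matmul as above, then a bias broadcast over the batch and the activation.
            outputs = tf.einsum("bml, mlk -> bmk", inputs, self.kernel) + self.bias
            return self.activation(outputs)


    Calling ParallelDenseWithBias(k, activation="relu") on the same inp again yields a (16, 200, 10) output.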