Tags: python, keras, jupyter, keras-layer

How to Efficiently Handle Repeated Shared Layers in Keras


I am new to deep learning with Keras. I am setting up a neural network model in Keras whose inputs xi (each input is a pair of variables) are pre-processed by a smaller dense sub-network. That sub-network emits two new variables xf (like an encoder), which are concatenated and then serve as the input to a rather larger dense network.

At the pre-processing stage, since the same layers run over each input xi, I was wondering whether this could be coded efficiently with a loop.

I tried a for-loop, but it did not work, since tensor objects do not support item assignment. So I have written it in the following straightforward manner, and the code works fine. However, this version is written for 5 inputs; in practice, the number of inputs is a variable that can go up to 500, so an efficient method is necessary.

from keras.layers import Input, Dense, concatenate
from keras.models import Model

# Shared layers: the same three Dense layers (and their weights)
# are reused by every input branch below
shared_dense_lvl1 = Dense(4, activation='sigmoid')
shared_dense_lvl2 = Dense(4, activation='sigmoid')
shared_dense_lvl3 = Dense(2, activation='sigmoid')

x1i = Input(shape=(2,))
x11 = shared_dense_lvl1(x1i)
x12 = shared_dense_lvl2(x11)
x1f = shared_dense_lvl3(x12)

x2i = Input(shape=(2,))
x21 = shared_dense_lvl1(x2i)
x22 = shared_dense_lvl2(x21)
x2f = shared_dense_lvl3(x22)

x3i = Input(shape=(2,))
x31 = shared_dense_lvl1(x3i)
x32 = shared_dense_lvl2(x31)
x3f = shared_dense_lvl3(x32)

x4i = Input(shape=(2,))
x41 = shared_dense_lvl1(x4i)
x42 = shared_dense_lvl2(x41)
x4f = shared_dense_lvl3(x42)

x5i = Input(shape=(2,))
x51 = shared_dense_lvl1(x5i)
x52 = shared_dense_lvl2(x51)
x5f = shared_dense_lvl3(x52)

xy0 = concatenate([x1f, x2f, x3f, x4f, x5f])
xy  = Dense(2, activation='relu',name='xy')(xy0)

y1  = Dense(50, activation='sigmoid',name='y1')(xy)
y2  = Dense(50, activation='sigmoid',name='y2')(y1)
y3  = Dense(250, activation='sigmoid',name='y3')(y2)

merged = Model(inputs=[x1i,x2i,x3i,x4i,x5i],outputs=[y3])
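
Incidentally, the for-loop attempt mentioned above can be made to work by appending each branch's output to a Python list instead of assigning into a tensor. A minimal sketch of that approach (n_inputs, feats, and h are hypothetical names, not from the original code):

n_inputs = 5  # in practice this could go up to 500

# One Input per xi; branch outputs are collected in a list,
# which avoids the item-assignment problem with tensors
inputs = [Input(shape=(2,)) for _ in range(n_inputs)]
feats = []
for xi in inputs:
    h = shared_dense_lvl1(xi)
    h = shared_dense_lvl2(h)
    feats.append(shared_dense_lvl3(h))

xy0 = concatenate(feats)  # same as concatenate([x1f, ..., x5f]) above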

Solution

  • The model you wrote is equivalent to:

    inp = Input(shape=(5, 2,))
    l1 = Dense(4, ...)(inp)
    l2 = Dense(4, ...)(l1)
    l3 = Dense(2, ...)(l2)
    xy0 = Flatten()(l3)
    

    In other words: if the input has more than two dimensions, for instance a shape such as (batch_size, time_steps, n_features), Dense layers are applied with shared weights across all time steps.
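
    A minimal runnable version of that idea, filling in the sigmoid activations from the question (the (500, 2) input shape is an assumption based on the stated upper bound of 500 inputs):

    from keras.layers import Input, Dense, Flatten
    from keras.models import Model

    inp = Input(shape=(500, 2))                  # one row per former xi input
    l1 = Dense(4, activation='sigmoid')(inp)     # applied row-wise, weights shared
    l2 = Dense(4, activation='sigmoid')(l1)
    l3 = Dense(2, activation='sigmoid')(l2)
    xy0 = Flatten()(l3)                          # same result as concatenating the xif
    xy = Dense(2, activation='relu', name='xy')(xy0)
    model = Model(inputs=inp, outputs=xy)
    model.compile(optimizer='adam', loss='mse')  # loss is a placeholder choice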

    Is your data something like a time series? If so, you may consider using a recurrent or convolutional network instead, as in the sketch below.
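
    If the data does have that structure, one option is a 1D convolution over the same (500, 2) input; a rough sketch (the filter count and kernel size are arbitrary illustrative choices):

    from keras.layers import Input, Conv1D, Flatten

    inp = Input(shape=(500, 2))
    # Each filter slides over the 500 positions, mixing neighbouring pairs
    c1 = Conv1D(filters=8, kernel_size=3, padding='same', activation='relu')(inp)
    feats = Flatten()(c1)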

    You also seem to be using a lot of sigmoid activations. Beware of vanishing gradients: if you find that the network is not training, check whether the gradients are vanishing by printing them for a specific training batch. This package, for instance, may help: https://github.com/philipperemy/keract
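
    As a sketch of that check: keract exposes a get_gradients_of_trainable_weights helper (consult the keract README for the exact, current API; the model must be compiled, and x_batch/y_batch below are random placeholders for a real training batch):

    import numpy as np
    import keract

    # Placeholder batch matching the (500, 2) single-input model sketched above
    x_batch = np.random.rand(32, 500, 2)
    y_batch = np.random.rand(32, 2)

    grads = keract.get_gradients_of_trainable_weights(model, x_batch, y_batch)
    for name, g in grads.items():
        print(name, np.abs(g).mean())  # values near zero suggest vanishing gradients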