python tensorflow keras tf.keras keras-layer

What is the proper way to add layers using Keras functional API?


I am trying to use the Keras functional API to build a model with two branches. I need to add the output of the convolutional branch (path23, shape (m, n, 5)) to the output of the identity branch (path10, shape (m, n, 1)), and the result should have shape (m, n, 1), not the (m, n, 5) that I get now. In other words, I want to sum the 5 channels of path23 together with path10, rather than have path10 broadcast across the 5 channels. How can I do this?

Please check the code and the picture attached.

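import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, ReLU, Conv1D, Add
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model

# `initializer` and `input_shape` are defined elsewhere in the script.

def define_neural_network_model(input_shape, outputs=1):
  input_layer = Input(shape=input_shape)
  # first path: identity, shape (m, n, 1)
  path10 = input_layer
  # second path: Dense -> ReLU -> Conv1D, shape (m, n, 5)
  path20 = input_layer
  path21 = Dense(1, use_bias=True, kernel_initializer=initializer)(path20)
  path22 = ReLU()(path21)
  path23 = Conv1D(filters=5, kernel_size=3, strides=1, padding="same", use_bias=True, kernel_initializer=initializer)(path22)
  # merge: path10 is broadcast over the 5 channels, so the output is (m, n, 5)
  output = Add()([path10, path23])

  # pass the name to the constructor instead of setting the private _name attribute
  model = Model(inputs=input_layer, outputs=output, name='Recovery')
  return model


neural_network_model = define_neural_network_model(input_shape)
# neural_network_model.summary()
plot_model(neural_network_model, to_file='generator_model.png', show_shapes=True, show_layer_names=True)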

[Image: model sketch]


Solution

  • It has been a while since I asked this question. I am answering it myself in case it is useful to someone else.

    The convolution output can be summed along its last (channel) axis inside a Lambda layer. Because the sum removes that axis, it has to be restored afterwards (here with expand_dims) so that the result has shape (m, n, 1) and can be added to the input without broadcasting. The code is below, and the neural network diagram is attached (1).

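    import tensorflow as tf
    from tensorflow.keras.layers import Input, Dense, ReLU, Conv1D, Lambda, Add
    from tensorflow.keras.models import Model
    from tensorflow.keras.utils import plot_model

    input_shape = (32767, 1)        # taken from the model summary below
    initializer = 'glorot_uniform'  # assumption: the initializer actually used is not shown

    def define_neural_network_model(input_shape, outputs=1):
      input_layer = Input(shape=input_shape)
      # first path: identity, shape (m, n, 1)
      path10 = input_layer
      # second path: Dense -> ReLU -> Conv1D, shape (m, n, 5)
      path20 = input_layer
      path21 = Dense(1, use_bias=True, kernel_initializer=initializer)(path20)
      path22 = ReLU()(path21)
      path23 = Conv1D(filters=5, kernel_size=3, strides=1, padding="same", use_bias=True, kernel_initializer=initializer)(path22)
      # sum the 5 channels, (m, n, 5) -> (m, n), then restore the last axis -> (m, n, 1)
      path24 = Lambda(lambda x: tf.expand_dims(tf.reduce_sum(x, axis=-1), axis=-1))(path23)
      # merge: both inputs are (m, n, 1), so no broadcasting takes place
      output = Add()([path10, path24])
      # pass the name to the constructor instead of setting the private _name attribute
      model = Model(inputs=input_layer, outputs=output, name='Recovery')
      return model

    neural_network_model = define_neural_network_model(input_shape)
    neural_network_model.summary()
    plot_model(neural_network_model, to_file='generator_model_corrected.png', show_shapes=True, show_layer_names=True)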
    
    
    Model: "Recovery"
    __________________________________________________________________________________________________
     Layer (type)                   Output Shape         Param #     Connected to                     
    ==================================================================================================
     input_12 (InputLayer)          [(None, 32767, 1)]   0           []                               
                                                                                                      
     dense_13 (Dense)               (None, 32767, 1)     2           ['input_12[0][0]']               
                                                                                                      
     re_lu_13 (ReLU)                (None, 32767, 1)     0           ['dense_13[0][0]']               
                                                                                                      
     conv1d_13 (Conv1D)             (None, 32767, 5)     20          ['re_lu_13[0][0]']               
                                                                                                      
     lambda_4 (Lambda)              (None, 32767, 1)     0           ['conv1d_13[0][0]']              
                                                                                                      
     add_11 (Add)                   (None, 32767, 1)     0           ['input_12[0][0]',               
                                                                      'lambda_4[0][0]']               
                                                                                                      
    ==================================================================================================
    Total params: 22
    Trainable params: 22
    Non-trainable params: 0
    __________________________________________________________________________________________________
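
For anyone who wants to check the shape arithmetic without building the model, here is a minimal NumPy sketch of the same idea. NumPy arrays stand in for the Keras tensors, and the sizes m = 2, n = 4 are made up for illustration (the model above uses n = 32767). It is only meant to show why the plain add gives (m, n, 5) and why sum followed by expand_dims gives (m, n, 1).

```python
import numpy as np

# Hypothetical small sizes standing in for (m, n); not taken from the original model.
m, n = 2, 4

# Stand-in for path10: the identity branch with a single channel.
path10 = np.ones((m, n, 1))
# Stand-in for path23: the convolution branch with five channels.
path23 = np.arange(m * n * 5, dtype=float).reshape(m, n, 5)

# Plain addition broadcasts the single channel across all five: shape (m, n, 5).
broadcast_sum = path10 + path23

# Summing over the channel axis drops that axis: (m, n, 5) -> (m, n).
summed = path23.sum(axis=-1)

# Restoring the axis gives (m, n, 1), which matches path10 exactly.
path24 = np.expand_dims(summed, axis=-1)

# keepdims=True performs the sum and keeps the axis in a single call.
path24_keepdims = path23.sum(axis=-1, keepdims=True)

# Both operands are (m, n, 1), so this addition involves no broadcasting.
output = path10 + path24

print(broadcast_sum.shape, summed.shape, path24.shape, output.shape)
```

Inside the Lambda layer, `tf.reduce_sum(x, axis=-1, keepdims=True)` should be the analogous single-call form, which would make the separate expand_dims step unnecessary.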