Tags: keras, neural-network, lstm, plaidml

Issue passing concatenated inputs to LSTM in keras


I have several neural networks. Their outputs are concatenated and then passed to an LSTM layer.

Here is a simplified code snippet:

# Minimal reproduction from the question: two single-input Dense branches
# whose outputs are concatenated and fed to an LSTM. As written, building
# the Model on the last line raises the AttributeError shown below.
import keras.backend as K

from keras.layers import Input, Dense, LSTM, concatenate
from keras.models import Model

# 1st NN
input_l1 = Input(shape=(1, ))
out_l1 = Dense(1)(input_l1)

# 2nd NN
input_l2 = Input(shape=(1, ))
out_l2 = Dense(1)(input_l2)

# concatenated layer: (None, 1) + (None, 1) -> (None, 2)
concat_vec = concatenate([out_l1, out_l2])

# expand dimensions to (None, 2, 1)
# NOTE(review): K.expand_dims is a backend op, not a Keras layer. Its result
# appears to carry no producing layer in `_keras_history`, so when Model()
# walks the graph back from lstm_out it reaches a `None` layer. That matches
# the "'NoneType' object has no attribute '_inbound_nodes'" error below.
# Doing the reshape with a layer (e.g. Reshape((2, 1))) avoids it.
expanded_concat = K.expand_dims(concat_vec, axis=2)

lstm_out = LSTM(10)(expanded_concat)

# NOTE(review): `import keras.backend as K` binds only `K`, not `keras`, so
# `keras.Model` relies on `keras` having been imported elsewhere
# (presumably an earlier notebook cell). `Model`, imported above, is
# presumably the same class.
model = keras.Model(inputs=[input_l1, input_l2], outputs=lstm_out)

Unfortunately I get an error on the last line:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-53-a16fe60c0fc3> in <module>
      2 lstm_out = LSTM(10)(expanded_concat)
      3 
----> 4 model = keras.Model(inputs=[input_l1, input_l2], outputs=lstm_out)

/usr/local/lib/python3.9/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in __init__(self, *args, **kwargs)
     91                 'inputs' in kwargs and 'outputs' in kwargs):
     92             # Graph network
---> 93             self._init_graph_network(*args, **kwargs)
     94         else:
     95             # Subclassed network

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in _init_graph_network(self, inputs, outputs, name)
    228 
    229         # Keep track of the network's nodes and layers.
--> 230         nodes, nodes_by_depth, layers, layers_by_depth = _map_graph_network(
    231             self.inputs, self.outputs)
    232         self._network_nodes = nodes

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in _map_graph_network(inputs, outputs)
   1361     for x in outputs:
   1362         layer, node_index, tensor_index = x._keras_history
-> 1363         build_map(x, finished_nodes, nodes_in_progress,
   1364                   layer=layer,
   1365                   node_index=node_index,

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)
   1350             node_index = node.node_indices[i]
   1351             tensor_index = node.tensor_indices[i]
-> 1352             build_map(x, finished_nodes, nodes_in_progress, layer,
   1353                       node_index, tensor_index)
   1354 

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)
   1323             ValueError: if a cycle is detected.
   1324         """
-> 1325         node = layer._inbound_nodes[node_index]
   1326 
   1327         # Prevent cycles.

AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

Is there a way to fix it? In case it matters, I use the PlaidML backend, since it is the only option on macOS that supports a discrete GPU.


Solution

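  • To achieve this, use a `Reshape` layer, which converts its input to the target shape. Unlike a backend op such as `K.expand_dims`, a layer records itself in its output tensor's Keras history. That is what lets `Model` trace the graph from the inputs to the outputs. The `AttributeError` above comes from that trace reaching a tensor with no producing layer.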

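    Keras is integrated into TensorFlow, so here is a working version using `tf.keras`. With standalone Keras (as used with the PlaidML backend), `keras.layers.Reshape` should work the same way.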

    import tensorflow as tf
    from tensorflow.keras.layers import Input, Dense, LSTM, Reshape, concatenate
    from tensorflow.keras.models import Model
    
    # 1st NN: one scalar input -> one scalar output
    input_l1 = Input(shape=(1, ))
    out_l1 = Dense(1)(input_l1)
    
    # 2nd NN: same structure as the first branch
    input_l2 = Input(shape=(1, ))
    out_l2 = Dense(1)(input_l2)
    
    # Concatenate the branch outputs: (None, 1) + (None, 1) -> (None, 2)
    concat_vec = concatenate([out_l1, out_l2])
    
    # LSTM expects (batch, timesteps, features), so turn (None, 2) into
    # (None, 2, 1): one timestep per concatenated branch output.
    # Use a Reshape *layer* rather than a backend op such as K.expand_dims.
    # In the standalone Keras from the question, a raw backend op leaves no
    # layer in the tensor's Keras history, so Model cannot trace the graph.
    # -1 lets Keras infer the timestep count, so this line keeps working
    # if more branches are concatenated.
    expanded_concat = Reshape((-1, 1))(concat_vec)
    
    lstm_out = LSTM(10)(expanded_concat)
    
    model = Model(inputs=[input_l1, input_l2], outputs=lstm_out)
    model.summary()
    

    Output:

    Model: "model"
    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    input_1 (InputLayer)            [(None, 1)]          0                                            
    __________________________________________________________________________________________________
    input_2 (InputLayer)            [(None, 1)]          0                                            
    __________________________________________________________________________________________________
    dense (Dense)                   (None, 1)            2           input_1[0][0]                    
    __________________________________________________________________________________________________
    dense_1 (Dense)                 (None, 1)            2           input_2[0][0]                    
    __________________________________________________________________________________________________
    concatenate (Concatenate)       (None, 2)            0           dense[0][0]                      
                                                                     dense_1[0][0]                    
    __________________________________________________________________________________________________
    reshape_1 (Reshape)             (None, 2, 1)         0           concatenate[0][0]                
    __________________________________________________________________________________________________
    lstm (LSTM)                     (None, 10)           480         reshape_1[0][0]                  
    ==================================================================================================
    Total params: 484
    Trainable params: 484
    Non-trainable params: 0
    __________________________________________________________________________________________________