python keras keras-layer

Error - input that isn't a symbolic tensor - Boolean


I am trying to adjust the weights of a Dense layer using a binary cross-entropy loss. I have created a shared layer that outputs one value for each of two input vectors (encoded_value_1 and encoded_value_2). I want to create a boolean that equals 1 if encoded_value_1 is greater than encoded_value_2. To do that, I use greater via a Lambda layer. However, it produces an error (see below).

import keras
from keras.backend import greater
from keras.layers import Input, LSTM, Dense, Lambda, concatenate
from keras.models import Model

value_1 = Input(shape=(4,))
value_2 = Input(shape=(4,))

shared_layer = Dense(1)
encoded_value_1 = shared_layer(value_1)
encoded_value_2 = shared_layer(value_2)

x = Lambda(greater,output_shape=(1,))((encoded_value_1,encoded_value_2)) 
model = Model(inputs=[value_1, value_2], outputs=x)
model.compile(optimizer='adam',loss='binary_crossentropy', metrics='accuracy'])

NB: I also tried concatenating the two layers first, and got the same error:

merged_vector = concatenate([encoded_value_1, encoded_value_2], axis=-1)
x = Lambda(greater,output_shape=(1,))((merged_vector[0],merged_vector[1]))

ValueError: Layer lambda_4 was called with an input that isn't a symbolic tensor. Received type: . Full input: [(, )]. All inputs to the layer should be tensors.


Solution

  • There are three points:

    1. When a Lambda layer has more than one input, the inputs must be passed as a list of tensors, not a tuple.

    2. The output of greater is a boolean tensor, which you need to cast to float before doing any computation on it (such as computing the loss).

    3. greater takes two arguments, but a Lambda layer calls its function with a single argument (the list of inputs), so you need to wrap greater in a Python lambda function that unpacks that list.

    Therefore, we would have:

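    from keras import backend as K
    from keras.layers import Lambda

    # ...
    # z is the list [encoded_value_1, encoded_value_2]; cast the bool result to float.
    x = Lambda(lambda z: K.cast(K.greater(z[0], z[1]), K.floatx()),
               output_shape=(1,))([encoded_value_1, encoded_value_2])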
    
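A minimal end-to-end sketch of the fixed model, assuming TensorFlow 2.x with its bundled tf.keras rather than the standalone keras package used above; tf.math.greater and tf.cast stand in for K.greater and K.cast, and the random input batch is purely illustrative. It also writes the metrics argument as a proper list.

```python
import numpy as np
import tensorflow as tf

value_1 = tf.keras.Input(shape=(4,))
value_2 = tf.keras.Input(shape=(4,))

# One Dense layer applied to both inputs, so both encodings share weights.
shared_layer = tf.keras.layers.Dense(1)
encoded_value_1 = shared_layer(value_1)
encoded_value_2 = shared_layer(value_2)

# The two tensors are passed as a list; the lambda receives that list as z.
# greater() yields booleans, so cast them to float (1.0 where z[0] > z[1]).
x = tf.keras.layers.Lambda(
    lambda z: tf.cast(tf.math.greater(z[0], z[1]), tf.float32),
    output_shape=(1,),
)([encoded_value_1, encoded_value_2])

model = tf.keras.Model(inputs=[value_1, value_2], outputs=x)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Usage: run the comparison on a small random batch.
rng = np.random.default_rng(0)
a = rng.normal(size=(5, 4)).astype("float32")
b = rng.normal(size=(5, 4)).astype("float32")
pred = model.predict([a, b], verbose=0)
```

Each row of pred is 1.0 where the shared layer scores the first vector higher and 0.0 otherwise.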

    Also, don't forget the missing opening bracket in the metrics argument:

    ..., metrics=['accuracy'])
                 ^
                 |
                 |
              missing!
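One further caveat, beyond the original answer: K.greater has no gradient, and casting a boolean to float does not add one. So although the fixed model compiles, binary cross-entropy cannot push any gradient back into shared_layer, and calling fit typically fails with an error about missing (None) gradients, the exact message depending on the Keras version. If the goal is to train the shared layer, one common differentiable substitute (swapped in here, not part of the answer) is sigmoid(encoded_value_1 - encoded_value_2), which exceeds 0.5 exactly when encoded_value_1 > encoded_value_2. The sketch below assumes TensorFlow 2.x and compares the gradients of both versions with tf.GradientTape; the helper names and the all-ones target y are hypothetical.

```python
import tensorflow as tf

tf.random.set_seed(0)
shared_layer = tf.keras.layers.Dense(1)
v1 = tf.random.normal((8, 4))
v2 = tf.random.normal((8, 4))
y = tf.ones((8, 1))   # hypothetical target: "v1 should score higher"
shared_layer(v1)      # build kernel and bias before taking gradients


def hard_output(e1, e2):
    # The answer's comparison: 1.0 where e1 > e2, else 0.0 (not differentiable).
    return tf.cast(tf.math.greater(e1, e2), tf.float32)


def soft_output(e1, e2):
    # Differentiable stand-in: sigmoid(e1 - e2) > 0.5 exactly when e1 > e2.
    return tf.math.sigmoid(e1 - e2)


def kernel_gradient(output_fn):
    with tf.GradientTape() as tape:
        pred = output_fn(shared_layer(v1), shared_layer(v2))
        loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y, pred))
    # trainable_weights[0] is the Dense kernel; None means no gradient path.
    return tape.gradient(loss, shared_layer.trainable_weights)[0]


hard_grad = kernel_gradient(hard_output)
soft_grad = kernel_gradient(soft_output)
```

In the model above, this would amount to replacing the Lambda body with tf.math.sigmoid(z[0] - z[1]) and thresholding the prediction at 0.5 whenever a hard 0/1 value is needed.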