How to concatenate tensors after expanding dimension using Keras Functional API?


I am trying to expand the dimension of an input tensor and then concatenate:

import tensorflow as tf
inp = tf.keras.layers.Input(shape=(1,))
inp = inp[..., tf.newaxis]
decoder_input = inp
output = tf.concat([inp, decoder_input], 1)
model = tf.keras.models.Model(inp, output )

But I get an error in the last line:

Exception has occurred: ValueError Graph disconnected: cannot obtain value for tensor Tensor("input_1:0", shape=(None, 1), dtype=float32) at layer "tf_op_layer_strided_slice". The following previous layers were accessed without issue: []


Solution

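  • Is this what you are trying to do? It looks like a variable conflict: the line inp = inp[..., tf.newaxis] rebinds inp to the expanded tensor, so Model(inp, output) receives that tensor instead of the Input layer, and Keras cannot trace the graph back to an input. Giving the expanded tensor a different name fixes it.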

    import tensorflow as tf
    inp = tf.keras.layers.Input(shape=(1,))
    
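    # Give the expanded tensor its own name; any of these 3 forms works
    x = tf.keras.layers.Reshape((-1,1))(inp)
    #x = tf.expand_dims(inp, axis=-1)
    #x = inp[...,tf.newaxis]
    
    # Note: x is not used below; the concat takes inp twice, hence (None, 2)
    decoder_input = inp
    output = tf.concat([inp, decoder_input], 1)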
    model = tf.keras.models.Model(inp, output)
    
    model.summary()
    
    Model: "functional_8"
    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    input_7 (InputLayer)            [(None, 1)]          0                                            
    __________________________________________________________________________________________________
    tf_op_layer_concat_4 (TensorFlo [(None, 2)]          0           input_7[0][0]                    
                                                                     input_7[0][0]                    
    ==================================================================================================
    Total params: 0
    Trainable params: 0
    Non-trainable params: 0
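
If the goal is to concatenate the expanded tensors, as the question's code suggests, a minimal sketch could look like the following. It assumes a TensorFlow 2.x install. The name expanded is only illustrative, and the Reshape and Concatenate layers are swapped in for the raw tf ops because newer Keras versions may reject tf.* functions on symbolic inputs.

```python
import tensorflow as tf

# Keep the Input tensor under its own name so Model() can find the graph root.
inp = tf.keras.layers.Input(shape=(1,))

# Expanded copy under a new (illustrative) name: (None, 1) -> (None, 1, 1).
expanded = tf.keras.layers.Reshape((1, 1))(inp)

decoder_input = expanded

# Join the two (None, 1, 1) tensors along axis 1.
output = tf.keras.layers.Concatenate(axis=1)([expanded, decoder_input])

# inp is still the original Input, so the graph is connected.
model = tf.keras.models.Model(inp, output)
print(model.output_shape)
```

With both tensors shaped (None, 1, 1) and joined on axis 1, the printed shape should be (None, 2, 1).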