TensorFlow: Using a Lambda layer to index into an array


I want to use a Lambda layer to retrieve values from a fixed array. Here is my toy example:

import tensorflow as tf
fixed_array = tf.random.uniform( shape=(5,32) )
index_input = tf.keras.Input( shape=(1,), dtype='int32' )
output = tf.keras.layers.Lambda( lambda x: fixed_array[ x[:,0] ] )( index_input )
model = tf.keras.Model( inputs=index_input, outputs=output )
model.compile()

But when I run it, I get the following error:

output = tf.keras.layers.Lambda( lambda x: fixed_array[ x ] )( index_input )
ValueError: Exception encountered when calling layer "lambda" (type Lambda).

Shape must be rank 1 but is rank 3 for '{{node lambda/strided_slice}} = StridedSlice[Index=DT_INT32, T=DT_FLOAT, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](lambda/6, lambda/strided_slice/stack, lambda/strided_slice/stack_1, lambda/strided_slice/stack_2)' with input shapes: [5,32], [1,?,1], [1,?,1], [1].

Call arguments received by layer "lambda" (type Lambda):
  • inputs=tf.Tensor(shape=(None, 1), dtype=int32)
  • mask=None
  • training=None

I do not understand the error message: first it says the rank is 3 (Shape must be rank 1 but is rank 3), but then it reports the input with shape (None, 1), which is not rank 3 (inputs=tf.Tensor(shape=(None, 1), dtype=int32)).

What is the error message telling me, and how do I fix this example?


Solution

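  • The error message is indeed cryptic. NumPy-style indexing on a tensor goes through Tensor.__getitem__, which builds a strided-slice op (the StridedSlice node named in the error). That op supports only basic indexing with integers, slices, and scalar tensors; it does not support indexing by a non-scalar tensor, and your x is one. The "rank 3" in the message appears to refer to the slice bounds that TensorFlow assembled from x (the [1,?,1] input shapes), which must be rank 1, and not to the layer input itself. That is why the two shapes in the message disagree. If you define your lambda function explicitly and print x, it will show: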

    Tensor("Placeholder:0", shape=(None, 1), dtype=int32)
    

    However, tf.gather accepts a tensor of indices, so you can use it like this:

    import tensorflow as tf
    
    fixed_array = tf.random.uniform( shape=(5,32) )
    index_input = tf.keras.Input( shape=(1,), dtype='int32' )
    output = tf.keras.layers.Lambda( lambda x: tf.gather(fixed_array, x[:,0]) )( index_input )
    model = tf.keras.Model( inputs=index_input, outputs=output )
    model.compile(loss=tf.keras.losses.MeanSquaredError())  # the loss is only used for training; predict() does not need it
    

    and then you can run for example:

    model.predict(tf.constant([0,2,4]))
    

    which will return the specified rows of fixed_array.
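
To check the lookup outside of Keras, here is a minimal eager-mode sketch. It assumes TensorFlow 2.x and NumPy are installed; the `indices` tensor is a hypothetical stand-in for a batch fed to the model. It compares tf.gather against NumPy fancy indexing on the same data, and shows why the trailing axis is dropped with `[:, 0]` first.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the fixed lookup table from the question.
fixed_array = tf.random.uniform(shape=(5, 32))

# A batch of indices shaped (batch, 1), like the Keras input above.
indices = tf.constant([[0], [2], [4]], dtype=tf.int32)

# Drop the trailing axis so the indices are rank 1, then gather rows.
rows = tf.gather(fixed_array, indices[:, 0])

# NumPy fancy indexing on the same data should select the same rows.
expected = fixed_array.numpy()[indices.numpy()[:, 0]]

print(rows.shape)                           # (3, 32)
print(np.allclose(rows.numpy(), expected))  # True

# Without [:, 0], the index shape is kept in the result.
unsqueezed = tf.gather(fixed_array, indices)
print(unsqueezed.shape)                     # (3, 1, 32)
```

If the extra axis in the last result is acceptable downstream, the `[:, 0]` can be left out; otherwise dropping it keeps the output at one row per batch element.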