I want to use a Lambda layer to retrieve values from a fixed array. Here is my toy example:
import tensorflow as tf
fixed_array = tf.random.uniform( shape=(5,32) )
index_input = tf.keras.Input( shape=(1,), dtype='int32' )
output = tf.keras.layers.Lambda( lambda x: fixed_array[ x[:,0] ] )( index_input )
model = tf.keras.Model( inputs=index_input, outputs=output )
model.compile()
But when I run it, I get the following error:
output = tf.keras.layers.Lambda( lambda x: fixed_array[ x ] )( index_input )
ValueError: Exception encountered when calling layer "lambda" (type Lambda).
Shape must be rank 1 but is rank 3 for '{{node lambda/strided_slice}} = StridedSlice[Index=DT_INT32, T=DT_FLOAT, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](lambda/6, lambda/strided_slice/stack, lambda/strided_slice/stack_1, lambda/strided_slice/stack_2)' with input shapes: [5,32], [1,?,1], [1,?,1], [1].
Call arguments received by layer "lambda" (type Lambda):
• inputs=tf.Tensor(shape=(None, 1), dtype=int32)
• mask=None
• training=None
I do not understand the error message: first it says that the rank is 3 (Shape must be rank 1 but is rank 3), but then it shows that the input has shape (None, 1) (inputs=tf.Tensor(shape=(None, 1), dtype=int32)).
What is the error message telling me, and how do I fix this example?
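The error message is very cryptic indeed, but the ranks it mentions do not refer to your input. They refer to the arguments of the StridedSlice op, whose input shapes are listed at the end of the message: [5,32] (fixed_array), [1,?,1] (begin), [1,?,1] (end) and [1] (strides). StridedSlice needs begin and end to be rank-1 vectors, but here they are rank 3. The underlying cause is that NumPy-style indexing compiles to tf.Tensor.__getitem__, which only supports basic indexing: a non-scalar tensor is not allowed as an index, and your x variable is such a tensor. If you define your lambda function explicitly and print x, it will show: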
Tensor("Placeholder:0", shape=(None, 1), dtype=int32)
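The rank-3 shape [1,?,1] in the message can be reproduced in eager mode. This is a minimal sketch, assuming (as the [1,?,1], [1,?,1], [1] input shapes in the error suggest) that indexing with a tensor packs that tensor into the begin/end arguments of StridedSlice with tf.stack; the batch of 3 indices is arbitrary.

```python
import tensorflow as tf

# Stand-in for the Keras input: a batch of 3 indices with shape (3, 1),
# the same layout that Input(shape=(1,)) produces (batch dimension first).
x = tf.constant([[0], [2], [4]], dtype=tf.int32)

# Assumption: fixed_array[x] puts the index tensor into a one-element list of
# "begin" values and stacks it, which adds a leading axis of size 1.
begin = tf.stack([x])
print(begin.shape)  # (1, 3, 1): rank 3, like [1,?,1] in the error

# With a plain Python int index, the same packing gives a rank-1 vector,
# which is what StridedSlice expects for begin/end/strides.
scalar_begin = tf.stack([2])
print(scalar_begin.shape)  # (1,)
```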
It seems, however, that you can use tf.gather to handle tensors of indices, like this:
import tensorflow as tf
fixed_array = tf.random.uniform( shape=(5,32) )
index_input = tf.keras.Input( shape=(1,), dtype='int32' )
output = tf.keras.layers.Lambda( lambda x: tf.gather(fixed_array, x[:,0]) )( index_input )
model = tf.keras.Model( inputs=index_input, outputs=output )
model.compile(loss=tf.keras.losses.MeanSquaredError())  # a loss is needed before fit(); predict() works without one
and then you can run for example:
model.predict(tf.constant([[0], [2], [4]]))  # shape (3, 1), matching Input(shape=(1,))
which will return the specified rows of fixed_array (here rows 0, 2 and 4, as an array of shape (3, 32)).
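To check the result, and to see why the lambda slices x[:,0] before gathering, here is a short sketch that rebuilds the same model. It compares the model output with tf.gather called directly, and shows that gathering with the (batch, 1) tensor itself would keep an extra axis. The index values are arbitrary, and the model is left uncompiled because only predict() is used.

```python
import numpy as np
import tensorflow as tf

fixed_array = tf.random.uniform(shape=(5, 32))
index_input = tf.keras.Input(shape=(1,), dtype='int32')
output = tf.keras.layers.Lambda(lambda x: tf.gather(fixed_array, x[:, 0]))(index_input)
model = tf.keras.Model(inputs=index_input, outputs=output)

indices = tf.constant([[0], [2], [4]], dtype=tf.int32)  # shape (3, 1)
result = model.predict(indices, verbose=0)               # NumPy array, shape (3, 32)

# The model output should be exactly rows 0, 2 and 4 of fixed_array.
expected = tf.gather(fixed_array, [0, 2, 4]).numpy()
np.testing.assert_allclose(result, expected)

# Gathering with the (3, 1) tensor itself keeps the length-1 index axis,
# giving (3, 1, 32); slicing [:, 0] first yields the flat (3, 32) result.
print(tf.gather(fixed_array, indices).shape)        # (3, 1, 32)
print(tf.gather(fixed_array, indices[:, 0]).shape)  # (3, 32)
```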