python, tensorflow, keras, keras-layer

Indexing a Keras Tensor


The output layer of my Keras functional model is a tensor x of shape (None, 1344, 2). I want to extract n < 1344 entries along the second dimension of x and create a new tensor y of shape (None, n, 2).

Extracting n consecutive entries is straightforward with x[:, :n, :], but it seems difficult when the n indices are non-consecutive. Is there a clean way to do this in Keras?

Here are my approaches so far.

Experiment 1 (Slicing a tensor, consecutive indices, works):

print('My tensor shape is', K.int_shape(x))  # my tensor
# printed: (None, 1344, 2)
print('Slicing first 5 entries, shape is', K.int_shape(x[:, :5, :]))
# printed: (None, 5, 2), works!

Experiment 2 (indexing the tensor at arbitrary indices, fails):

import numpy as np

print('My tensor shape is', K.int_shape(x))  # my tensor
# printed: (None, 1344, 2)
foo = np.array([1, 2, 4, 5, 8])
print('arbitrary indexing, shape is', K.int_shape(x[:, foo, :]))

Keras raises the following error:

ValueError: Shapes must be equal rank, but are 1 and 0
From merging shape 1 with other shapes. for 'strided_slice_17/stack_1' (op: 'Pack') with input shapes: [], [5], [].

Experiment 3 (TensorFlow backend function): I have also tried K.backend.gather, but its usage is unclear to me for two reasons. 1) The Keras documentation states that the indices should be a tensor of integers, and there is no Keras equivalent of numpy.where if my goal is to extract the entries of x that satisfy a certain condition. 2) K.backend.gather appears to extract entries along axis 0, whereas I want to extract along the second dimension of x.
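For reference, both concerns can be checked directly against TensorFlow's own ops. A minimal sketch, assuming TensorFlow 2.x in eager mode; x, foo, and weights are hypothetical stand-ins. tf.gather takes an axis argument, so it can select along the second dimension, and tf.where with a single boolean argument returns the coordinates of the True entries, much like numpy.argwhere.

```python
import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x, eager mode. Hypothetical stand-in for the model
# output: batch of 3, 6 positions along axis 1, 2 channels.
x = tf.reshape(tf.range(3 * 6 * 2, dtype=tf.float32), (3, 6, 2))

# tf.gather accepts an `axis` argument, so it can pick entries along the
# second dimension instead of axis 0. A plain integer NumPy array works as indices.
foo = np.array([1, 2, 4, 5])
y = tf.gather(x, foo, axis=1)  # shape (3, 4, 2)

# tf.where(cond) with one argument returns the coordinates of True entries,
# shape (num_true, rank). Hypothetical per-position condition, identical for
# every sample, so the selected count (and output shape) is fixed.
weights = tf.constant([0.1, 0.9, 0.2, 0.8, 0.7, 0.3])  # hypothetical
idx = tf.reshape(tf.where(weights > 0.5), [-1])       # int64 positions 1, 3, 4
z = tf.gather(x, idx, axis=1)                         # shape (3, 3, 2)
```

A condition evaluated on the data of each sample would generally select a different number of positions per sample, so it could not yield a fixed (None, n, 2) tensor; that is why the sketch uses a condition shared by all samples.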


Solution

  • You are looking for tf.gather_nd, which indexes based on an index array:

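    # Adapted from the tf.gather_nd documentation
    import tensorflow as tf

    indices = [[0, 0], [1, 1]]
    params = [['a', 'b'], ['c', 'd']]
    output = tf.gather_nd(params, indices)  # -> [b'a', b'd'] (a string tensor)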
    

    To use it in a Keras model, wrap it in a layer such as Lambda.
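A minimal sketch of the Lambda wrapping inside a functional model, assuming TensorFlow 2.x with tf.keras. The Input stands in for the original (None, 1344, 2) output, and foo holds the hypothetical positions from Experiment 2. It shows tf.gather with axis=1 alongside tf.gather_nd with batch_dims=1; the latter needs an (n, 1) index array per sample, so the indices are tiled over the batch, whose size is unknown until run time.

```python
import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x with tf.keras. Stand-in for the original output.
inputs = tf.keras.Input(shape=(1344, 2))
foo = np.array([1, 2, 4, 5, 8])  # hypothetical non-consecutive positions

# Option 1: tf.gather along axis 1 (no per-sample indices needed).
y_gather = tf.keras.layers.Lambda(lambda t: tf.gather(t, foo, axis=1))(inputs)

# Option 2: tf.gather_nd with batch_dims=1, as suggested above.
def gather_positions(t):
    idx = tf.constant(foo.reshape(-1, 1), dtype=tf.int32)   # (n, 1)
    batch = tf.shape(t)[0]                                  # dynamic batch size
    idx = tf.tile(idx[tf.newaxis], [batch, 1, 1])           # (batch, n, 1)
    return tf.gather_nd(t, idx, batch_dims=1)               # (batch, n, 2)

y_gather_nd = tf.keras.layers.Lambda(gather_positions)(inputs)

model = tf.keras.Model(inputs, [y_gather, y_gather_nd])
```

Both outputs contain the same values; tf.gather is the shorter route because it needs no per-sample index tensor, and y_gather has the static shape (None, 5, 2).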