Tags: python, tensorflow, keras, tensor, array-broadcasting

Access elements of a Tensor


I have the following TensorFlow tensors.

tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0, 255)
tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0, 255)

tensor3 = tf.keras.backend.flatten(tensor1)
tensor4 = tf.keras.backend.flatten(tensor2)

tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32') #All elements in range [0, 255)

I want to use the values stored in tensor3 and tensor4 to form a tuple and look up the element at the position given by that tuple in tensor5. For example, say tensor3[0] = 5 and tensor4[0] = 99, so the tuple becomes (5, 99); I then want to look up the value at position (5, 99) in tensor5. I want to do this for all elements of tensor3 and tensor4 in a batched manner, i.e. without looping over range(len(tensor3)). I did the following to achieve this:

tensor6 = tensor5[tensor3[0],tensor4[0]]

But tensor6 has shape (255, 255), whereas I was hoping to get a tensor of shape (len(tensor3),). I want to evaluate tensor5 at the index pair taken from the same position in tensor3 and tensor4, i.e. at (tensor3[0], tensor4[0]), (tensor3[1000], tensor4[1000]), (tensor3[2000], tensor4[2000]), and so on. I am using TensorFlow version 1.12.0. How can I achieve this?
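
To make the intended result concrete, here is a minimal NumPy sketch of what I am after (the explicit Python loop is exactly what I want to avoid on the TensorFlow side; the variable names are just illustrative):

import numpy as np

t3 = np.random.randint(0, 255, size=2 * 512 * 512)   # plays the role of tensor3
t4 = np.random.randint(0, 255, size=2 * 512 * 512)   # plays the role of tensor4
t5 = np.random.randint(0, 255, size=(255, 255))      # plays the role of tensor5

# Desired output: one looked-up value per position i, i.e. shape (len(t3),)
expected = np.array([t5[t3[i], t4[i]] for i in range(len(t3))])
print(expected.shape)  # (524288,)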


Solution

  • I have managed to get something working in TensorFlow v1.12, but do let me know if it is the expected code:

    import tensorflow as tf
    print(tf.__version__)
    import numpy as np
    
    tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0, 255)
    tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0, 255)
    
    tensor3 = tf.keras.backend.flatten(tensor1)
    tensor4 = tf.keras.backend.flatten(tensor2)
    
    tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32') #All elements in range [0, 255)
    
    # Pair the two flattened index tensors; map_fn iterates over them in lockstep
    elems = (tensor3, tensor4)
    # For each pair x = (i, j), look up tensor5[i, j]; dtype is required here because
    # the output structure (a single int32 tensor) differs from the structure of elems
    a = tf.map_fn(lambda x: tensor5[x[0], x[1]], elems, dtype=tf.int32)
    
    print(tf.Session().run(a))
    

    Based on the comment below, I'd like to add an explanation of the map_fn used in the code. Since Python for loops over tensor values are not supported without eager execution, map_fn is (sort of) the equivalent of a for loop here, producing one looked-up value per element of tensor3.

    A map_fn call takes the following parameters: operation_performed, input_arguments, optional_dtype. Under the hood, a loop is run along the first dimension of the values in input_arguments (which must be a tensor, or a tuple/list of tensors), operation_performed is applied to each slice, and the results are stacked back into a single tensor. For further details, please refer to the docs.
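
    As a rough illustration only (not how TensorFlow actually executes it, and using small NumPy stand-ins rather than the real tensors above), the map_fn call behaves like this plain Python loop:

    import numpy as np
    
    # Loop-based sketch of tf.map_fn for this use case (illustration only)
    def map_fn_like(operation_performed, input_arguments):
        first, second = input_arguments
        # "for loop" along the first dimension of both inputs, in lockstep
        return np.array([operation_performed((i, j)) for i, j in zip(first, second)])
    
    t3 = np.array([5, 17, 200])                   # stand-in for a few values of tensor3
    t4 = np.array([99, 3, 42])                    # stand-in for a few values of tensor4
    t5 = np.random.randint(0, 255, (255, 255))    # stand-in for tensor5
    
    # One looked-up value per position -> shape (3,)
    print(map_fn_like(lambda x: t5[x[0], x[1]], (t3, t4)))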

    The names given to the arguments above are just my way of interpreting them, as I'd like to understand them, and are not used in the official docs. :)