
Custom function with multiple arguments and one return value in map_fn for tensor objects in TensorFlow


I have two tensors t1 and t2 (shape=(64,64,3), dtype=tf.float64). I want to execute a custom function "func" which takes two tensors as input and returns one new tensor.

@tf.function
def func(a):
  t1 = a[0]
  t2 = a[1]
  
  return tf.add(t1, t2)

I am using TensorFlow's map_fn to execute the function for each element of the inputs.

t = tf.map_fn(fn=func, elems=(t1, t2), dtype=(tf.float64, tf.float64))
tf.print(t)

Sample input tensors for testing purposes:

t1 = tf.constant([[1.1, 2.2, 3.3],
                  [4.4, 5.5, 6.6]], dtype=tf.float64)
t2 = tf.constant([[7.7, 8.8, 9.9],
                  [10.1, 11.11, 12.12]], dtype=tf.float64)

I cannot get map_fn to work with two arguments. (I also tried tf.stack and tf.unstack, but that didn't work either.) Any idea how to do that?
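For reference, map_fn unstacks "elems" along axis 0, calls the function once per slice, and stacks the per-slice results back together. A minimal sketch with a single hypothetical tensor x (not from the question), assuming a TensorFlow 2.x release recent enough to have the fn_output_signature argument (older releases use dtype):

```python
import tensorflow as tf

# Hypothetical input: two rows, three columns.
x = tf.constant([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]], dtype=tf.float64)

# map_fn slices x along axis 0, so the lambda is called once per row
# (each row has shape [3]); the scalar results are stacked into shape [2].
row_sums = tf.map_fn(lambda row: tf.reduce_sum(row), x,
                     fn_output_signature=tf.float64)
print(row_sums.numpy())  # [ 6. 15.]
```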


Solution

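  • The "elems" parameter of "map_fn" unpacks whatever is passed to it along axis 0 and calls the function once per slice. So, to hand several tensors to the custom function in a single call, one option is to:

    1. Stack them together.
    2. Add an extra dimension along axis 0, so that "map_fn" sees exactly one element holding both tensors.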
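    # t1 and t2 have shape [2, 3] and dtype tf.float64
    val = tf.stack([t1, t2])           # shape is now [2, 2, 3]
    val = tf.expand_dims(val, axis=0)  # shape is now [1, 2, 2, 3]
    # map_fn slices val along axis 0, so func is called once, on a [2, 2, 3] slice
    t = tf.map_fn(fn=func, elems=val, dtype=tf.float64)  # t has shape [1, 2, 3]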
    

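    Also, the "dtype" of "map_fn" must describe what the function returns, not what it receives. Here func returns a single tensor, so it should be tf.float64 rather than (tf.float64, tf.float64); only if the function returned a tuple would the dtype also be a tuple. In recent TensorFlow 2.x releases this argument is named "fn_output_signature", and "dtype" is kept as a deprecated equivalent.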

    @tf.function
    def func(a): # a has shape [2, 2, 3]
      t1 = a[0] # shape [2, 3]
      t2 = a[1] # shape [2, 3]
    
      return tf.add(t1, t2)
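Putting the pieces together, here is a self-contained sketch. It assumes a TensorFlow 2.x release with fn_output_signature and builds the question's sample values as tf.float64. It runs the stack-and-expand approach above and, for comparison, passes a tuple directly as "elems". The tuple variant is an alternative that is not part of the original answer; it works because map_fn also accepts a nested structure of tensors and slices each one along axis 0.

```python
import tensorflow as tf

# Sample inputs from the question, created as float64 to match the stated dtype.
t1 = tf.constant([[1.1, 2.2, 3.3],
                  [4.4, 5.5, 6.6]], dtype=tf.float64)
t2 = tf.constant([[7.7, 8.8, 9.9],
                  [10.1, 11.11, 12.12]], dtype=tf.float64)

@tf.function
def func(a):
    # a[0] and a[1] are the two operands: either the two halves of a
    # stacked tensor or the two members of a tuple slice.
    return tf.add(a[0], a[1])

# Approach from the answer: stack, add a leading axis, then map once.
val = tf.expand_dims(tf.stack([t1, t2]), axis=0)  # shape [1, 2, 2, 3]
t = tf.map_fn(fn=func, elems=val, fn_output_signature=tf.float64)
print(t.shape)  # (1, 2, 3)

# Alternative: pass a tuple as elems. map_fn slices t1 and t2 along axis 0
# and calls func once per row with a (t1_row, t2_row) tuple. The output
# signature is a single float64 because func returns one tensor.
rows = tf.map_fn(fn=func, elems=(t1, t2), fn_output_signature=tf.float64)
print(rows.shape)  # (2, 3)
```

The two calls differ in what gets mapped. With the stacked tensor, func runs once on the whole [2, 2, 3] slice and the result keeps the extra leading axis. With the tuple, func runs once per row and the result has the same shape as t1.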