Tags: python, tensorflow, tensorflow2.0, tensorflow-datasets

How to randomly select from set of functions in TensorFlow using tf.function


My problem is this: during pre-processing, I want to apply a function randomly selected from a set of functions to each dataset example, using the tf.data.Dataset and tf.function APIs.

Specifically, my data are 3D volumes and I wish to apply a rotation drawn from a set of 24 predefined rotation functions. I would like to write this code within a tf.function, which limits the use of packages like NumPy and of Python list indexing.
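
One of the 24 rotations could, for instance, be written with pure TensorFlow ops roughly like this (a sketch only; the exact rotation bodies are not important for the question):

import tensorflow as tf

def rotate_90_hw(volume):
    # Rotate a [depth, height, width] volume by 90 degrees in the
    # (height, width) plane, the analogue of np.rot90(volume, k=1, axes=(1, 2)):
    # flip along the width axis, then swap the height and width axes.
    return tf.transpose(tf.reverse(volume, axis=[2]), perm=[0, 2, 1])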

For example, I would like to do something like this:

import tensorflow as tf

@tf.function
def func1(tensor):
    # Apply some rotation here
    ...

@tf.function
def func2(tensor):
    ...

...

@tf.function
def func24(tensor):
    ...


@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, ..., func24]

    # Randomly sample an index in 0-23 (maxval is exclusive)
    a = tf.random.uniform([1], minval=0, maxval=24, dtype=tf.int32)
    
    return list_of_funcs[a](tensor)

However, I cannot index list_of_funcs this way: it fails with TypeError: list indices must be integers or slices, not Tensor. Additionally, I cannot (AFAIK) collect these functions into a tf.Tensor and use tf.gather.

So my question: how can I reasonably and neatly sample from these functions in a tf.function?


Solution

  • You can use tf.switch_case, like this:

    import tensorflow as tf

    def func1(tensor):
        return tensor * 1
    
    def func2(tensor):
        return tensor * 2
    
    def func24(tensor):
        return tensor * 24
    
    # Binds a function to its argument so that each branch becomes a
    # zero-argument callable, which is the form tf.switch_case expects.
    class Lambda:
        def __init__(self, func, arg):
            self._func = func
            self._arg = arg

        def __call__(self):
            return self._func(self._arg)
    
    @tf.function
    def apply(tensor):
        list_of_funcs = [func1, func2, func24]
    
        # maxval is exclusive, so this samples an index in [0, len(list_of_funcs))
        branch_index = tf.random.uniform(shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32)
        # Exactly one branch is executed, selected by branch_index
        output = tf.switch_case(
            branch_index=branch_index,
            branch_fns=[Lambda(func, tensor) for func in list_of_funcs],
        )
        
        return output
    
    

    The @tf.function decorator is only needed on the outermost function you want to optimize, which is apply in this case. If you use apply inside tf.data.Dataset.map, the decorator is not needed at all, since map traces the mapped function into a graph anyway (see the usage sketch at the end of this answer).

    See this discussion to understand why we define the Lambda class here instead of using plain lambdas: lambdas created inside the list comprehension would all close over the same loop variable func, and because of Python's late binding every branch would end up calling the last function in the list.
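
    To see that late-binding pitfall concretely, here is a small eager-mode sketch (illustrative only, reusing func1, func2, and func24 from above); functools.partial is an equivalent fix to the Lambda class:

    import functools

    funcs = [func1, func2, func24]
    x = tf.ones([])

    # Broken: every lambda closes over the same loop variable `f`,
    # so all three callables end up calling func24.
    broken = [lambda: f(x) for f in funcs]
    print([b().numpy() for b in broken])   # [24.0, 24.0, 24.0]

    # Bound at construction time, like the Lambda class above.
    safe = [functools.partial(f, x) for f in funcs]
    print([s().numpy() for s in safe])     # [1.0, 2.0, 24.0]

    And here is a minimal usage sketch of apply with tf.data.Dataset.map; the dataset of random [4, 4, 4] "volumes" is purely an illustrative assumption:

    # Toy dataset of 8 volumes; a randomly selected branch is applied to each element.
    volumes = tf.random.normal([8, 4, 4, 4])
    ds = tf.data.Dataset.from_tensor_slices(volumes).map(apply)

    for example in ds.take(2):
        print(example.shape)  # (4, 4, 4)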