My problem is this: during pre-processing I want to apply a function randomly selected from a set of functions to dataset examples, using the tf.data.Dataset and tf.function APIs.
Specifically, my data are 3D volumes and I wish to apply a rotation from a set of 24 predefined rotation functions. I would like to write this code within a tf.function, which limits the use of packages like numpy and of list indexing.
For example, I would like to do something like this:
    import tensorflow as tf

    @tf.function
    def func1(tensor):
        # Apply some rotation here
        ...

    @tf.function
    def func2(tensor):
        ...

    ...

    @tf.function
    def func24(tensor):
        ...

    @tf.function
    def apply(tensor):
        list_of_funcs = [func1, func2, ..., func24]
        # Randomly sample from 0-23 (maxval is exclusive for integer dtypes)
        a = tf.random.uniform([1], minval=0, maxval=24, dtype=tf.int32)
        return list_of_funcs[a](tensor)
However, I cannot index list_of_funcs with a tensor, as this raises TypeError: list indices must be integers or slices, not Tensor. Additionally, I cannot (AFAIK) collect these functions into a tf.Tensor and use tf.gather.
So my question: how can I reasonably and neatly sample from these functions in a tf.function?
You can use tf.switch_case, like this:
    import tensorflow as tf

    def func1(tensor):
        return tensor * 1

    def func2(tensor):
        return tensor * 2

    def func24(tensor):
        return tensor * 24

    class Lambda:
        """Zero-argument callable that binds func to arg at construction time."""

        def __init__(self, func, arg):
            self._func = func
            self._arg = arg

        def __call__(self):
            return self._func(self._arg)

    @tf.function
    def apply(tensor):
        list_of_funcs = [func1, func2, func24]
        # Scalar index in [0, len(list_of_funcs)); maxval is exclusive.
        branch_index = tf.random.uniform(shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32)
        # tf.switch_case expects callables that take no arguments.
        output = tf.switch_case(
            branch_index=branch_index,
            branch_fns=[Lambda(func, tensor) for func in list_of_funcs],
        )
        return output
The @tf.function decorator is needed only on the outermost function you wish to optimize, which is apply in this case. If you use apply inside tf.data.Dataset.map, the decorator is not needed at all, because map traces the function into a graph itself.
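To illustrate the tf.data.Dataset.map case, here is a minimal sketch. The name apply_random and the toy dataset of all-ones volumes are assumptions made up for this example, and the three scaling functions merely stand in for the real rotations.

```python
import tensorflow as tf


def func1(tensor):
    return tensor * 1


def func2(tensor):
    return tensor * 2


def func24(tensor):
    return tensor * 24


class Lambda:
    """Zero-argument callable that binds func to arg at construction time."""

    def __init__(self, func, arg):
        self._func = func
        self._arg = arg

    def __call__(self):
        return self._func(self._arg)


def apply_random(tensor):
    # Hypothetical name. No @tf.function here: Dataset.map traces this
    # function into a graph on its own.
    list_of_funcs = [func1, func2, func24]
    branch_index = tf.random.uniform(
        shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32
    )
    return tf.switch_case(
        branch_index=branch_index,
        branch_fns=[Lambda(func, tensor) for func in list_of_funcs],
    )


# Toy data (an assumption): eight 4x4x4 volumes filled with ones.
volumes = tf.ones([8, 4, 4, 4])
dataset = tf.data.Dataset.from_tensor_slices(volumes).map(apply_random)

# Each example was scaled by a branch sampled independently for that example.
scales = [float(example[0, 0, 0]) for example in dataset]
print(scales)
```

Every entry of scales should be 1.0, 2.0 or 24.0, and the values vary from run to run because the branch index is drawn anew for each example.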
See this discussion to understand why we have to define the Lambda class here.
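As a rough illustration, and assuming the issue in question is Python's late binding of loop variables inside closures, the plain-Python sketch below shows what goes wrong with a bare lambda in a comprehension. It also shows two standard ways of binding the function early that play the same role as the Lambda class. The variable names are made up for this example, and no TensorFlow is involved.

```python
import functools


def func1(x):
    return x * 1


def func2(x):
    return x * 2


def func24(x):
    return x * 24


funcs = [func1, func2, func24]
value = 10

# Late binding: each lambda looks up `func` only when it is called, by which
# time the comprehension has finished and `func` refers to func24.
late_bound = [lambda: func(value) for func in funcs]
late_results = [fn() for fn in late_bound]

# Early binding via a default argument, evaluated when the lambda is created.
default_bound = [lambda func=func: func(value) for func in funcs]
default_results = [fn() for fn in default_bound]

# Early binding via functools.partial, which stores func and value immediately.
partial_bound = [functools.partial(func, value) for func in funcs]
partial_results = [fn() for fn in partial_bound]

print(late_results)     # [240, 240, 240]
print(default_results)  # [10, 20, 240]
print(partial_results)  # [10, 20, 240]
```

With the bare lambda, every branch would end up calling the last function in the list. A small callable class like Lambda avoids this by storing the function and its argument when each instance is constructed.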