Tags: python, tensorflow, tensorflow2

Select an item from a list of objects of any type when using TensorFlow 2.x


Given a list of instances of class A, [A(i) for i in range(5)], I want to randomly select one of them and call it (see the following code for an example):

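import random

class A:
    def __init__(self, a):
        self.a = a
    def __call__(self):
        return self.a

def f():
    a_list = [A(i) for i in range(5)]
    # randint includes both endpoints, so the upper bound is len(a_list) - 1
    a = a_list[random.randint(0, len(a_list) - 1)]()
    return a

f()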

Is there a way to decorate f with @tf.function without changing what f does and without calling every item in a_list?

Note that directly decorating f with @tf.function, without any other change to the code above, does not work: it always returns the same result, because random.randint is plain Python that runs only once, while the function is being traced. I also know this can be achieved by calling every element of a_list first and then indexing the results with tf.gather_nd, but that incurs a large overhead if calling an object of type A involves a deep neural network.
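A minimal sketch of the two behaviours described in the note, assuming TensorFlow 2.x. The names f_naive and f_gather are hypothetical. The first function shows why plain decoration freezes the choice: the Python random call runs once at trace time, so the chosen index is baked into the graph. The second is the call-everything-then-index workaround; it uses tf.gather instead of tf.gather_nd because the index here is a single scalar into a 1-D tensor.

```python
import random
import tensorflow as tf

class A:
    def __init__(self, a):
        self.a = a
    def __call__(self):
        return self.a

@tf.function
def f_naive():
    a_list = [A(i) for i in range(5)]
    # random.randint executes once, during tracing, so every later call
    # reuses the same index.
    return a_list[random.randint(0, len(a_list) - 1)]()

@tf.function
def f_gather():
    a_list = [A(i) for i in range(5)]
    # Evaluate every element (costly if each call runs a network)...
    outputs = tf.stack([a() for a in a_list])
    # ...then pick one using an index drawn inside the graph on each call.
    idx = tf.random.uniform(shape=[], maxval=len(a_list), dtype=tf.int32)
    return tf.gather(outputs, idx)

print([int(f_naive()) for _ in range(5)])
print([int(f_gather()) for _ in range(5)])
```

Repeated calls to f_naive() return one fixed value, while f_gather() varies between calls but pays for every call to A each time.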


Solution

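  • I'm working on the same thing at the moment, and here's what I've got so far; if anyone knows a better way, I'd be interested to hear it too. The idea is to let tf.switch_case choose the branch inside the graph, so only the selected element's computation runs at execution time. When I run it on an expensive call, it is correspondingly faster than computing and returning all of the values.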

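    import tensorflow as tf

    @tf.function
    def f2():
        a_list = [A(i) for i in range(5)]
        # Integer index in [0, len(a_list)); maxval is exclusive, so every element can be chosen
        idx = tf.random.uniform(shape=[], maxval=len(a_list), dtype=tf.int32)
        # tf.switch_case traces every branch but executes only the one at idx
        return tf.switch_case(idx, a_list)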
    

    For a speed comparison, I made the __call__ method of A perform expensive matrix algebra and increased a_list to 40 elements. Then consider an alternative function that calls every element:

    @tf.function
    def f3():
        a_list = [A(i) for i in range(40)]
        results = [a() for a in a_list]
        return results
    

    Running f2 with 40 elements: 0.42643 seconds

    Running f3 with 40 elements: 14.9153 seconds

    That is a ratio of about 35x (14.9153 / 0.42643), close to the roughly 40x speedup you would expect from running only one of the 40 branches.
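One way to confirm that tf.switch_case executes only the selected branch is to count executions with a tf.Variable. This is a sketch under stated assumptions: CountingA and calls are hypothetical names, and a_list is built once outside the function rather than inside it as in f2 above. Every branch's Python code still runs once at trace time; the counter only measures what happens when the graph executes.

```python
import tensorflow as tf

class CountingA:
    """Hypothetical variant of A that also counts how often it is executed."""

    def __init__(self, a, counter):
        self.a = a
        self.counter = counter

    def __call__(self):
        # Make the returned value depend on the increment so the stateful
        # assign_add runs whenever this branch is actually executed.
        with tf.control_dependencies([self.counter.assign_add(1)]):
            return tf.identity(tf.constant(self.a, dtype=tf.int32))

calls = tf.Variable(0, dtype=tf.int32)
a_list = [CountingA(i, calls) for i in range(5)]

@tf.function
def f2():
    # Integer index in [0, len(a_list)); maxval is exclusive.
    idx = tf.random.uniform(shape=[], maxval=len(a_list), dtype=tf.int32)
    # All branches are traced, but only the branch at idx runs per call.
    return tf.switch_case(idx, a_list)

results = [int(f2()) for _ in range(20)]
print(results)
print(int(calls.numpy()))
```

After 20 calls the counter should read 20 rather than 100, since only one of the five branches runs per call, and the results should vary from call to call.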
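A rough sketch of the kind of benchmark described above, assuming each A performs repeated matrix multiplications. ExpensiveA, best_time, and the size and steps values are hypothetical choices, not the answerer's actual setup, so absolute timings and the exact ratio will differ. The matrices are stored in tf.Variable objects so that graph optimizations are less likely to precompute the products as constants.

```python
import time
import tensorflow as tf

class ExpensiveA:
    """Hypothetical stand-in for A whose call performs heavy matrix algebra."""

    def __init__(self, seed, size=128, steps=10):
        # Scale by 1/sqrt(size) to keep repeated products numerically tame.
        init = tf.random.stateless_normal([size, size], seed=[seed, 0]) / size ** 0.5
        # A Variable (rather than a constant) keeps the work from being
        # folded away during graph optimization.
        self.m = tf.Variable(init)
        self.steps = steps

    def __call__(self):
        x = self.m
        for _ in range(self.steps):
            x = tf.linalg.matmul(x, self.m)
        return tf.reduce_sum(x)

N = 40
a_list = [ExpensiveA(i) for i in range(N)]

@tf.function
def f2():
    idx = tf.random.uniform(shape=[], maxval=N, dtype=tf.int32)
    return tf.switch_case(idx, a_list)  # executes one branch per call

@tf.function
def f3():
    return tf.stack([a() for a in a_list])  # executes all N elements

def best_time(fn, repeats=5):
    fn()  # the first call traces the function, so keep it out of the timing
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn().numpy()  # .numpy() blocks until the result is available
        times.append(time.perf_counter() - start)
    return min(times)

t2, t3 = best_time(f2), best_time(f3)
print(f"f2: {t2:.4f} s  f3: {t3:.4f} s  ratio: {t3 / t2:.1f}x")
```

Taking the minimum of several runs after a warm-up call reduces noise from tracing and scheduling; the ratio should approach N when each element's work dominates the per-call overhead.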