Tags: pytorch, parallel-processing

Efficient processing of many small torch.nn modules


I am looking for a more efficient implementation of the following method:

import torch

function_list = [torch.sin, torch.exp, torch.tanh, etc.]  # candidate functions
function_choice = [0, 1, 2, 3, 2, etc.]                    # length: batchsize

def weird_function(x):
    # x has shape [1, dimension]
    y = torch.zeros(batchsize, dimension)
    for i in range(batchsize):
        y[i, :] = function_list[function_choice[i]](x)
    return y

In English: row i of the returned value y is function_list[function_choice[i]] applied to the input x, so each row of y is a (possibly different) function applied to the same input x.

My problem is that, written like this, the program is slow when batchsize is large (say, several hundred). Is there a better way to achieve the same result? Thank you in advance for your help.


Solution

  • I haven't had time to check your problem carefully, but it seems like it can be solved with PyTorch's vmap.
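
    Here is a minimal sketch of how vmap might be applied here (my own illustration, assuming a recent PyTorch with torch.func.vmap; the helper apply_choice and the one-hot selection trick are not from the original answer): evaluate every candidate function on the shared input, and let vmap pick the requested result for each batch element.

    import torch
    from torch.func import vmap  # assumes torch >= 2.0

    function_list = [torch.sin, torch.exp, torch.tanh]  # candidate functions
    batchsize, dimension = 5, 16
    function_choice = torch.tensor([0, 1, 2, 2, 1])     # length: batchsize
    x = torch.randn(1, dimension)

    def apply_choice(one_hot_row, x_row):
        # evaluate every candidate function on the (shared, unbatched) input
        candidates = torch.stack([f(x_row) for f in function_list])  # [num_funcs, dimension]
        # select the requested function via a one-hot weighted sum
        return (one_hot_row.unsqueeze(-1) * candidates).sum(dim=0)   # [dimension]

    one_hot = torch.nn.functional.one_hot(
        function_choice, num_classes=len(function_list)
    ).to(x.dtype)                                                     # [batchsize, num_funcs]

    # map over the batch of one-hot selectors; the input x is shared (in_dims=None)
    y = vmap(apply_choice, in_dims=(0, None))(one_hot, x.squeeze(0))  # [batchsize, dimension]

    Since x is the same for every batch element in your code, you could also just evaluate each candidate function once, stack the results into a [num_funcs, dimension] tensor, and index it with function_choice; vmap becomes more useful if x itself is batched.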