python, pytorch

Apply a generic function all-to-all to a pair of tensors


I have a question similar to this one, using a generic function of the kind:

def f(a, b):
   return Something

The result can be a scalar or a tensor. f should be applied all-to-all to the rows of two matrices, which may be aliased, like this:

def apply_very_slow(source1, source2):
    n, m = source1.shape[0], source2.shape[0]
    # assumes f returns a scalar; add trailing dimensions if f returns a tensor
    s = torch.empty(n, m)
    for i in range(n):
        for j in range(m):
            p = f(source1[i], source2[j])
            s[i, j] = p
    return s

Of course, this is extremely slow. I could call it with the same argument twice, apply_very_slow(M, M), or with different tensors, apply_very_slow(M, N), and in the latter case I have no guarantee about whether M and N point at the same data.

Moreover, I'd like the results to be autograd-friendly. It is no problem if f must be converted to a class. The question I linked uses torch.nn.functional.kl_div, but I need the solution to be generic.

Any hints?


Solution

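  • You want to apply f in an outer fashion, i.e. in the form i,j->ij, with one result per pair of rows. If f is built from operations that broadcast, you can do this by unsqueezing a dimension on each input so that the two broadcast against each other: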

    >>> import torch
    >>> import torch.nn.functional as F
    >>> M = torch.rand(10, 3)
    >>> N = torch.rand(5, 3)
    

    Here are some example functions applied (F.kl_div, F.l1_loss, and F.pairwise_distance):

     >>> y = F.kl_div(M[:,None], N[None], reduction='none')
     # shape of (10, 5, 3) last dim not reduced
    
     >>> y = F.l1_loss(M[:,None], N[None], reduction='none')
     # shape of (10, 5, 3) last dim not reduced
    
     >>> F.pairwise_distance(M[:,None], N[None])
     # shape of (10, 5) last dim reduced
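
The broadcasting approach above relies on f accepting batched inputs. For an f written to take two single rows, one possible way to get the same all-to-all result is to nest torch.vmap. The following is only a sketch: it assumes PyTorch 2.0 or later and an f made of operations that vmap supports. The function f below and the helper names apply_pairwise and apply_loop are hypothetical, chosen for illustration.

```python
import torch


def f(a, b):
    # Hypothetical example function: squared Euclidean distance between two rows.
    return ((a - b) ** 2).sum()


def apply_pairwise(f, source1, source2):
    # The inner vmap maps f over the rows of source2 while a row of source1 is held fixed;
    # the outer vmap then maps that over the rows of source1.
    over_rows2 = torch.vmap(f, in_dims=(None, 0))
    return torch.vmap(over_rows2, in_dims=(0, None))(source1, source2)


def apply_loop(f, source1, source2):
    # Reference implementation: the explicit double loop from the question.
    return torch.stack([
        torch.stack([f(a, b) for b in source2]) for a in source1
    ])


M = torch.rand(10, 3, requires_grad=True)
N = torch.rand(5, 3)

out = apply_pairwise(f, M, N)    # shape (10, 5)
same = apply_pairwise(f, M, M)   # same tensor passed twice, shape (10, 10)
out.sum().backward()             # gradients flow back to M
```

If f returns a tensor rather than a scalar, its dimensions should simply be appended after the two pair dimensions. If f contains data-dependent Python control flow or operations that vmap does not support, this sketch will not apply as written, and the broadcasting form or the explicit loop remains the fallback.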