pytorch

`a_{ij}+b_{kj}=c_{ik}` contraction in torch


The title is self-explanatory: I need to compute this contraction in torch. Broadcasting a row vector against a column vector gives a_i + b_j = c_{ij}, which is close, but I haven't been able to generalize it to two matrices.
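
For concreteness, here is a minimal sketch of the 1-D broadcasting case described above; the tensor names and sizes are illustrative.

    import torch

    a = torch.arange(3.0)                # shape (3,), index i
    b = torch.arange(4.0)                # shape (4,), index j
    c = a.unsqueeze(1) + b.unsqueeze(0)  # shape (3, 4), c[i, j] == a[i] + b[j]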


Solution

  • Inserting a size-1 dimension makes broadcasting copy the values across the corresponding dimension of the other tensor during the addition. Here is the summation contraction you probably want (a shape check follows after this answer).

    # A: (I, J), B: (K, J); broadcast over the shared j axis and sum it out -> shape (I, K), i.e. c_{ik}
    (A.unsqueeze(1) + B.unsqueeze(0)).sum(-1)
    

    Update: I misunderstood your question in my previous einsum answer.
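
As a quick sanity check, assuming A has shape (I, J) and B has shape (K, J), the broadcast-and-sum expression above can be compared against an equivalent factored form: the sum over j distributes over the addition, so c_{ik} reduces to the row sums of A and B, which avoids materializing the (I, K, J) intermediate.

    import torch

    I, J, K = 3, 5, 4                                        # illustrative sizes
    A = torch.randn(I, J)
    B = torch.randn(K, J)

    # broadcast to (I, K, J), then sum out the shared j axis
    C = (A.unsqueeze(1) + B.unsqueeze(0)).sum(-1)            # shape (I, K)

    # the j-sum distributes over the addition, so row sums suffice
    C_cheap = A.sum(1).unsqueeze(1) + B.sum(1).unsqueeze(0)  # shape (I, K)

    assert C.shape == (I, K)
    assert torch.allclose(C, C_cheap)

The factored form is usually the better choice when J is large, since the broadcast version allocates the full (I, K, J) tensor before reducing it.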