I am trying to vectorize the following Python code (it carries over to Matlab with minor changes), which performs sum aggregation of edge features at the receiving nodes of a directed graph:
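```python
for j in range(batchSize):              # j: graph in the batch
    for i in range(2*nEdges[j]):        # i: edge slot of graph j
        # add the features of edge i to its receiving node
        localSum[j, receivers[j, i], 0:2] += localFeature[j, i, 0:2]
```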
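I know how to remove one of the loops when the operation is a plain assignment (`=`), but I could not find a way to do the same for this `+=` case. The difficulty is that `receivers[j,i]` can take the same value for different `i` (within the same `j`), and a fancy-indexed `+=` then applies only one of the duplicate updates instead of accumulating all of them.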
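A small NumPy sketch of that pitfall, using hypothetical toy data (`messages` stands in for a single feature column): the fancy-indexed `+=` gathers, adds and writes back in one buffered step, so for a repeated receiver only one of its updates survives, whereas the explicit loop accumulates every edge.

```python
import numpy as np

# Hypothetical toy data: 4 nodes, 4 edges; node 1 receives two messages.
receivers = np.array([0, 1, 1, 3])
messages = np.array([10.0, 1.0, 2.0, 5.0])

# Naive vectorization of "localSum[receivers[i]] += messages[i]".
# The right-hand side is computed first and then written back, so for the
# repeated index 1 only one of the two updates ends up in the result.
naive = np.zeros(4)
naive[receivers] += messages

# What the loop actually computes: every edge contributes to its receiver.
looped = np.zeros(4)
for i in range(len(receivers)):
    looped[receivers[i]] += messages[i]
```

Here `looped[1]` is `3.0`, while `naive[1]` holds only one of the two messages.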
Does anyone have an idea how to vectorize this case?
Thanks a lot :)
Finally found it. The `torch_scatter` package offers this functionality (for PyTorch tensors). The nested loop can be rewritten as

```python
import torch_scatter

localSum = torch_scatter.scatter(localFeature, receivers, dim=1, reduce='sum')
```
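For readers without PyTorch, here is a hedged NumPy sketch of the same reduction using `np.add.at`, which, unlike a fancy-indexed `+=`, is unbuffered and accumulates repeated indices. The shapes are assumptions (features of shape `(batchSize, maxSlots, nFeat)`, receivers of shape `(batchSize, maxSlots)`, padded beyond `2*nEdges[j]`), and `loop_sum` / `vectorized_sum` are hypothetical names; the sketch checks itself against the original loop on random data.

```python
import numpy as np

def loop_sum(localFeature, receivers, nEdges, nNodes):
    """Reference: the nested loop from the question."""
    batchSize = localFeature.shape[0]
    localSum = np.zeros((batchSize, nNodes, 2))
    for j in range(batchSize):
        for i in range(2 * nEdges[j]):
            localSum[j, receivers[j, i], 0:2] += localFeature[j, i, 0:2]
    return localSum

def vectorized_sum(localFeature, receivers, nEdges, nNodes):
    """Loop-free version using the unbuffered np.add.at."""
    batchSize, maxSlots = receivers.shape
    # valid[j, i] is True for the slots the loop visits (i < 2*nEdges[j]);
    # padded slots beyond that are skipped.
    valid = np.arange(maxSlots)[None, :] < 2 * np.asarray(nEdges)[:, None]
    j_idx, i_idx = np.nonzero(valid)
    localSum = np.zeros((batchSize, nNodes, 2))
    # Every (graph, receiver) pair gets each matching feature row added,
    # even when the same pair occurs several times.
    np.add.at(localSum, (j_idx, receivers[j_idx, i_idx]),
              localFeature[j_idx, i_idx, 0:2])
    return localSum

# Hypothetical random data with padded edge slots.
rng = np.random.default_rng(0)
batchSize, nNodes, maxSlots = 3, 5, 8
nEdges = np.array([4, 2, 3])            # loop visits 8, 4 and 6 slots
receivers = rng.integers(0, nNodes, size=(batchSize, maxSlots))
localFeature = rng.normal(size=(batchSize, maxSlots, 3))

expected = loop_sum(localFeature, receivers, nEdges, nNodes)
result = vectorized_sum(localFeature, receivers, nEdges, nNodes)
```

As far as I understand the `torch_scatter` one-liner, it sums every slot of `receivers` and every feature column, so it reproduces the loop only if padded slots hold zero features (and valid node indices) and `localFeature` has just the two features the loop uses; passing `dim_size=nNodes` also fixes the output width, which otherwise defaults to `receivers.max() + 1`. The mask in `vectorized_sum` handles the padding explicitly instead.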