I am performing a multi-index rearrangement of a matrix based on its correspondence data. Right now I am doing this with a pair of index_select calls, but this is very memory inefficient (n^2 memory usage) and not ideal in terms of computation either. Is there some way I can boil my operation down into a single .gather or .index_select call?
What I essentially want to do is when given a source array of shape (I,J,K), and an array of indices of shape (I,J,2), produce a result which meets the condition:
result[i][j][:] = source[idx[i][j][0]][idx[i][j][1]][:]
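To pin down exactly what that condition means, here is a slow but unambiguous reference version using plain Python loops. It is only a sketch for checking faster versions against; the function name `gather_pairs_loop` and the example tensors are hypothetical, chosen for illustration.

```python
import torch

def gather_pairs_loop(source: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference implementation of
    #   result[i][j][:] = source[idx[i][j][0]][idx[i][j][1]][:]
    # It loops in Python, so it is only meant for testing, not for real data.
    I, J, _ = idx.shape
    result = torch.empty((I, J) + source.shape[2:], dtype=source.dtype,
                         device=source.device)
    for i in range(I):
        for j in range(J):
            a, b = idx[i, j].tolist()   # the (row, col) pair for output cell (i, j)
            result[i, j] = source[a, b]  # copy the whole trailing K-vector
    return result

source = torch.arange(24).reshape(4, 3, 2)          # (I, J, K) = (4, 3, 2)
idx = torch.tensor([[[2, 2], [3, 1], [0, 2]]] * 4)  # (I, J, 2) = (4, 3, 2)
out = gather_pairs_loop(source, idx)
print(out.shape)  # torch.Size([4, 3, 2])
```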
Here's a runnable toy example of how I'm doing things right now:
import torch

source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])
indices = torch.tensor([[[2,2],[3,1],[0,2]],[[0,2],[0,1],[0,2]],[[0,2],[0,1],[0,2]],[[0,2],[0,1],[0,2]]])
# Select rows for every index pair, then columns for every pair: an (n, n) intermediate
ax1 = torch.index_select(source, 0, indices[:,:,0].flatten())
ax2 = torch.index_select(ax1, 1, indices[:,:,1].flatten())
# Only the diagonal (row k, column k) holds the wanted values
result = ax2.diagonal().reshape(indices.shape[0], indices.shape[1])
This approach works for me only because my images are rather small, so they fit into memory even with the diagonalization issue. Regardless, I am producing a massive amount of data that doesn't need to exist. Furthermore, if K becomes large, the problem only gets worse, since the wasted intermediate scales with K as well. Perhaps I'm just missing something obvious in the documentation, but surely somebody else has run into this problem before and can help me out!
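The memory cost can be made concrete by printing the intermediate shapes for the toy example. With n = I * J index pairs (12 here), the second index_select materializes an n x n tensor, of which only the n diagonal entries are kept. This is just a diagnostic sketch built from the question's own code.

```python
import torch

# The toy example from the question: source is (4, 3), indices is (4, 3, 2).
source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[[2, 2], [3, 1], [0, 2]],
                        [[0, 2], [0, 1], [0, 2]],
                        [[0, 2], [0, 1], [0, 2]],
                        [[0, 2], [0, 1], [0, 2]]])

ax1 = torch.index_select(source, 0, indices[:, :, 0].flatten())  # (n, 3)
ax2 = torch.index_select(ax1, 1, indices[:, :, 1].flatten())     # (n, n)
print(ax1.shape, ax2.shape)  # torch.Size([12, 3]) torch.Size([12, 12])

# Only n of the n * n intermediate entries end up in the result.
print(ax2.numel(), ax2.diagonal().numel())  # 144 12
```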
You already have your indices in a convenient form for integer array (advanced) indexing, so you can simply do:
result = source[indices[..., 0], indices[..., 1], ...]
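A quick way to convince yourself this matches the intended semantics is to run it on a 3-D source of shape (I, J, K). The sketch below also shows a second option, which is my own addition rather than part of the answer: fold the first two source dimensions into one and use a single index_select with linear indices a * J + b, which is close to the "single call" the question asked for. Both function names are hypothetical helpers; neither approach builds an n x n intermediate.

```python
import torch

def select_pairs(source, idx):
    # Advanced indexing: idx[..., 0] and idx[..., 1] are both (I, J), so the
    # result is (I, J, *source.shape[2:]) with result[i, j] = source[a, b].
    return source[idx[..., 0], idx[..., 1], ...]

def select_pairs_flat(source, idx):
    # Alternative sketch: flatten the first two source dims and pick rows with
    # one index_select call using the linear index a * J_src + b.
    J_src = source.shape[1]
    flat = source.reshape(-1, *source.shape[2:])         # (I_src * J_src, K)
    lin = (idx[..., 0] * J_src + idx[..., 1]).flatten()  # (I * J,)
    out = torch.index_select(flat, 0, lin)               # (I * J, K)
    return out.reshape(*idx.shape[:-1], *source.shape[2:])

source = torch.arange(24).reshape(4, 3, 2)          # (I, J, K) = (4, 3, 2)
idx = torch.tensor([[[2, 2], [3, 1], [0, 2]]] * 4)  # (I, J, 2) = (4, 3, 2)
a = select_pairs(source, idx)
b = select_pairs_flat(source, idx)
print(a.shape)            # torch.Size([4, 3, 2])
print(torch.equal(a, b))  # True
```

Both functions also work on the 2-D toy example from the question, where the trailing `...` simply matches zero dimensions and the result is (4, 3).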