
Sort a multi-dimensional tensor using another tensor


I'm trying to sort C (see image) using R to get Sorted_C.

[image not included]

import torch

c = torch.tensor([[[0, 1, 0, 0, 0], [1, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], [[0, 0, 1, 1, 0], [1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0]]])
r = torch.tensor([[[0, 0, 0, 7.2, 0], [0, 25.4, 0, 0, 0], [0, 0, 43.6, 0, 0], [61.8, 0, 0, 0, 0], [0, 0, 0, 0, 80]], [[0, 0, 0, 0, 98.2], [116.4, 0, 0, 0, 0], [0, 134.6, 0, 0, 0], [0, 0, 152.8, 0, 0], [0, 0, 0, 169.2, 0]]])

# this is what I need
sorted_c = torch.tensor([[[0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]], [[0, 0, 0, 1, 1], [0, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 1, 0, 1]]])

How do I do this efficiently?

correction: expected ordering for P6 to P10 should be: P10 -> P6 -> P7 -> P8 -> P9


Solution

  • Minor note: your example r doesn't seem to quite match the values in your drawing, and the order for p6 through p10 should be p10 < p6 < p7 < p8 < p9.


    When you hear "advanced indexing", think torch.gather. Here, the result comes from indexing c along its columns (the last dimension) with an index tensor that we will derive from r.

    First, sum r over dim=1 and argsort the result to get the column order for each batch:

    >>> idx = r.sum(1).argsort(1)
    >>> idx
    tensor([[3, 1, 2, 0, 4],
            [4, 0, 1, 2, 3]])
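
    To make the two steps visible, here is a minimal sketch that prints the intermediate column sums before sorting them. The name scores is introduced only for illustration. In this example every column of r holds a single non-zero value, so the sum simply picks that value out.

```python
import torch

r = torch.tensor([[[0, 0, 0, 7.2, 0], [0, 25.4, 0, 0, 0], [0, 0, 43.6, 0, 0],
                   [61.8, 0, 0, 0, 0], [0, 0, 0, 0, 80]],
                  [[0, 0, 0, 0, 98.2], [116.4, 0, 0, 0, 0], [0, 134.6, 0, 0, 0],
                   [0, 0, 152.8, 0, 0], [0, 0, 0, 169.2, 0]]])

# Summing over dim=1 collapses each 5x5 matrix into one score per column.
scores = r.sum(1)        # shape (2, 5)

# argsort along dim=1 returns, per batch, the column indices in ascending score order.
idx = scores.argsort(1)  # shape (2, 5)

print(scores)
print(idx)
```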
    

    Then we can apply torch.Tensor.gather to index c column-wise with the column indices contained in idx, i.e. dim=2 is the dimension whose position is looked up in idx. Explicitly, with idx expanded to three dimensions (see the next step), the resulting tensor out is constructed such that:

    out[i][j][k] = c[i][j][idx[i][j][k]]
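
    As a sanity check of that rule, the sketch below spells it out with plain loops and compares the result against gather on a small made-up tensor. gather_dim2_reference, src and index are hypothetical names used only for this illustration. The loop version is far too slow for real use.

```python
import torch

def gather_dim2_reference(src, index):
    """Naive loop version of src.gather(2, index) (illustrative helper only)."""
    out = torch.empty_like(index, dtype=src.dtype)
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            for k in range(index.shape[2]):
                # Only the position along dim=2 is looked up in index.
                out[i, j, k] = src[i, j, index[i, j, k]]
    return out

# Made-up data: src has shape (2, 3, 4), index has shape (2, 3, 2).
src = torch.arange(24).reshape(2, 3, 4)
index = torch.tensor([[[3, 0], [1, 1], [2, 0]],
                      [[0, 3], [2, 2], [1, 0]]])

reference = gather_dim2_reference(src, index)
builtin = src.gather(2, index)
print(torch.equal(reference, builtin))
```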
    

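    Keep in mind that gather requires the index tensor (idx) and the value tensor (c) to have the same number of dimensions, and idx must not be larger than c along any dimension other than the one being indexed, here dim=2. Since idx is 2-D with shape (2, 5) and c is 3-D with shape (2, 4, 5), we need to expand idx to the shape of c. We can do so by inserting an axis with None-indexing and then calling expand or expand_as: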

    >>> idx[:,None].expand_as(c)
    tensor([[[3, 1, 2, 0, 4],
             [3, 1, 2, 0, 4],
             [3, 1, 2, 0, 4],
             [3, 1, 2, 0, 4]],
    
            [[4, 0, 1, 2, 3],
             [4, 0, 1, 2, 3],
             [4, 0, 1, 2, 3],
             [4, 0, 1, 2, 3]]])
    

    Notice the values duplicated row-wise (FYI: they're not copies; expand returns a view, not a copy!)
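
    One way to convince yourself that no data is duplicated is to inspect the strides and the underlying storage pointer. This is a small sketch; it passes the target shape (2, 4, 5) to expand directly instead of using expand_as(c), so that it runs without c.

```python
import torch

idx = torch.tensor([[3, 1, 2, 0, 4],
                    [4, 0, 1, 2, 3]])

# Insert a new axis at position 1, then broadcast it to length 4.
expanded = idx[:, None].expand(2, 4, 5)

print(expanded.shape)
# The expanded dimension has stride 0: all four "rows" read the same memory.
print(expanded.stride())
# The view starts at the same memory address as the original tensor.
print(expanded.data_ptr() == idx.data_ptr())
```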

    Finally, we can gather the values in c to get the desired result:

    >>> c.gather(2, idx[:,None].expand_as(c))
    tensor([[[0, 1, 0, 0, 0],
             [0, 0, 1, 1, 0],
             [1, 0, 0, 0, 0],
             [0, 0, 0, 0, 1]],
    
            [[0, 0, 0, 1, 1],
             [0, 1, 1, 0, 0],
             [0, 0, 0, 1, 0],
             [0, 0, 1, 0, 1]]])
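
    Putting the steps together, a self-contained sketch could look like this. sort_columns_by is a hypothetical helper name, and the check at the end compares against the sorted_c given in the question.

```python
import torch

def sort_columns_by(c, r):
    """Reorder the columns (dim=2) of c by ascending column sums of r (illustrative helper)."""
    idx = r.sum(1).argsort(1)                      # (batch, cols)
    return c.gather(2, idx[:, None].expand_as(c))  # (batch, rows, cols)

c = torch.tensor([[[0, 1, 0, 0, 0], [1, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]],
                  [[0, 0, 1, 1, 0], [1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0]]])
r = torch.tensor([[[0, 0, 0, 7.2, 0], [0, 25.4, 0, 0, 0], [0, 0, 43.6, 0, 0],
                   [61.8, 0, 0, 0, 0], [0, 0, 0, 0, 80]],
                  [[0, 0, 0, 0, 98.2], [116.4, 0, 0, 0, 0], [0, 134.6, 0, 0, 0],
                   [0, 0, 152.8, 0, 0], [0, 0, 0, 169.2, 0]]])

# Expected result, copied from the question.
sorted_c = torch.tensor([[[0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]],
                         [[0, 0, 0, 1, 1], [0, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 1, 0, 1]]])

result = sort_columns_by(c, r)
print(torch.equal(result, sorted_c))
```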