
PyTorch tensor: sort rows based on a column


Given a 2D tensor like this one:

tensor([[0.8771, 0.0976, 0.8186],
        [0.7044, 0.4783, 0.0350],
        [0.4239, 0.8341, 0.3693],
        [0.5568, 0.9175, 0.0763],
        [0.0876, 0.1651, 0.2776]])

How do you sort the rows based on the values in one column? For instance, if we sort based on the last column, I would expect the rows to be ordered like this:

tensor([[0.7044, 0.4783, 0.0350],
        [0.5568, 0.9175, 0.0763],
        [0.0876, 0.1651, 0.2776],
        [0.4239, 0.8341, 0.3693],
        [0.8771, 0.0976, 0.8186]])

Values in the last column are now in ascending order.
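To make the example reproducible, the question's tensor can be built with `torch.tensor` instead of random values. The minimal sketch below, which assumes the same row-indexing approach as the solution, checks that sorting by the last column produces exactly the expected tensor:

```python
import torch

# The 5x3 tensor from the question (fixed values instead of torch.rand).
t = torch.tensor([[0.8771, 0.0976, 0.8186],
                  [0.7044, 0.4783, 0.0350],
                  [0.4239, 0.8341, 0.3693],
                  [0.5568, 0.9175, 0.0763],
                  [0.0876, 0.1651, 0.2776]])

# The expected result from the question.
expected = torch.tensor([[0.7044, 0.4783, 0.0350],
                         [0.5568, 0.9175, 0.0763],
                         [0.0876, 0.1651, 0.2776],
                         [0.4239, 0.8341, 0.3693],
                         [0.8771, 0.0976, 0.8186]])

# Row indices that put the last column (index -1) in ascending order.
order = t[:, -1].sort().indices
print(order)  # tensor([1, 3, 4, 2, 0])

# Reorder whole rows and compare with the expected tensor.
result = t[order]
print(torch.equal(result, expected))  # True
```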


Solution

    import torch

    t = torch.rand(5, 3)
    COL_INDEX_TO_SORT = 2

    # sort() returns a (values, indices) pair: the sorted column and, for each
    # sorted position, the index of the row it came from in the original tensor.
    # The [1] at the end selects the indices.
    sorted_indices = t[:, COL_INDEX_TO_SORT].sort()[1]

    # Indexing with those indices reorders whole rows, so each row keeps its values.
    t = t[sorted_indices]
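The same idea can be wrapped in a small reusable function. The sketch below is one possible variant, not the original answer's code: it swaps `sort()[1]` for `torch.argsort`, which returns only the indices, and exposes its `descending` flag. The helper name `sort_rows_by_column` and the sample data are hypothetical.

```python
import torch


def sort_rows_by_column(t: torch.Tensor, col: int, descending: bool = False) -> torch.Tensor:
    """Hypothetical helper: return the rows of 2-D tensor `t` ordered by column `col`."""
    # argsort gives the row permutation directly, without the sorted values.
    order = torch.argsort(t[:, col], descending=descending)
    # Indexing with a 1-D index tensor selects whole rows in that order.
    return t[order]


t = torch.tensor([[3.0, 10.0],
                  [1.0, 20.0],
                  [2.0, 30.0]])

print(sort_rows_by_column(t, 0))                   # rows ascending by column 0
print(sort_rows_by_column(t, 1, descending=True))  # rows descending by column 1
```

By default `torch.argsort` does not guarantee the relative order of rows whose values in the chosen column are equal, so ties may come out in any order.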