Tags: sorting, indexing, pytorch, 2d

How to make argsort return 2D indexes


torch.argsort returns indices within a row or a column, depending on dim=...

How can I instead get 2D indices, e.g.:

[[r1, r2, r3, ...], [c1, c2, c3, ...]]
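To make the problem concrete, here is a small sketch with a made-up 2 x 3 tensor x: argsort along dim=1 gives column positions within each row, and the row index is only implicit in where each result sits. For a per-row sort, pairing the result with torch.arange recovers the [[rows], [cols]] layout; a global sort over all elements instead needs the flat indices to be unravelled.

```python
import torch

# Hypothetical example tensor with 2 rows and 3 columns.
x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.5, 9.0, 4.0]])

# Sorting along dim=1 returns column positions *within* each row.
col_idx = torch.argsort(x, dim=1)
# tensor([[1, 2, 0],
#         [0, 2, 1]])

# The row index is implicit; make it explicit with arange + broadcasting.
row_idx = torch.arange(x.shape[0]).unsqueeze(1).expand_as(col_idx)

# Flatten both into the [[r1, r2, ...], [c1, c2, ...]] layout.
coords = torch.stack((row_idx.reshape(-1), col_idx.reshape(-1)))
print(coords)
# tensor([[0, 0, 0, 1, 1, 1],
#         [1, 2, 0, 0, 2, 1]])
```

Indexing with x[coords[0], coords[1]] then yields each row's values in sorted order.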

Thanks. This is what I did:

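import torch

# Based on https://github.com/pytorch/pytorch/issues/35674
def unravel_indices(indices, shape):
    """Convert flat indices into coordinates for a tensor of the given shape.

    Returns a tensor of shape (*indices.shape, len(shape)).
    """
    coord = []

    # Peel off dimensions from the last (fastest-varying) to the first.
    for dim in reversed(shape):
        coord.append(torch.fmod(indices, dim))
        indices = torch.div(indices, dim, rounding_mode='floor')

    # coord was built last-dimension-first, so reverse it before stacking.
    coord = torch.stack(coord[::-1], dim=-1)

    return coord

# Optional: attach the helper to the torch namespace (monkeypatching).
torch.unravel_indices = unravel_indices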
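For the original goal, this helper could be combined with a global sort roughly as follows. This is a sketch with a made-up example tensor, and the helper is repeated so the snippet runs on its own: argsort the flattened tensor, unravel the flat positions against x.shape, and transpose the (N, 2) result into the [[rows], [cols]] layout.

```python
import torch

def unravel_indices(indices, shape):
    # Same helper as above: flat indices -> (..., len(shape)) coordinates.
    coord = []
    for dim in reversed(shape):
        coord.append(torch.fmod(indices, dim))
        indices = torch.div(indices, dim, rounding_mode='floor')
    return torch.stack(coord[::-1], dim=-1)

# Hypothetical 2 x 3 example tensor.
x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.5, 9.0, 4.0]])

# Sort all elements globally by argsorting the flattened tensor.
flat_order = torch.argsort(x.flatten())  # tensor([3, 1, 2, 0, 5, 4])

# (N, 2) tensor of (row, col) pairs; .T gives [[rows...], [cols...]].
coords = unravel_indices(flat_order, x.shape).T
print(coords)
# tensor([[1, 0, 0, 0, 1, 1],
#         [0, 1, 2, 0, 2, 1]])
```

The transpose is only needed for the row-list/column-list layout from the question; the untransposed (N, 2) form is handier when iterating over (row, col) pairs.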

Solution

  • Although NumPy has unravel_index for this, PyTorch has no built-in equivalent, but you can do it yourself. It is easy enough for two dimensions:

    yy, xx = indices // width, indices % width  # width = number of columns
    

    FWIW, PyTorch has had a feature request and PR(s) for this floating around for several years now.
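Spelled out for a 2D tensor, the answer's one-liner might look like the sketch below, with width taken as the number of columns (x.shape[1]) and a made-up example tensor. To our knowledge, PyTorch 2.2 and later ship torch.unravel_index, which returns a tuple with one coordinate tensor per dimension; the snippet uses it when present and otherwise falls back to the manual arithmetic.

```python
import torch

# Hypothetical 2 x 3 example tensor.
x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.5, 9.0, 4.0]])
width = x.shape[1]  # number of columns

# Flat positions of all elements in ascending order of value.
indices = torch.argsort(x.flatten())

if hasattr(torch, "unravel_index"):
    # Built-in in newer PyTorch releases; returns one tensor per dimension.
    yy, xx = torch.unravel_index(indices, x.shape)
else:
    # Manual 2D unravelling: row = flat // width, col = flat % width.
    yy, xx = indices // width, indices % width

coords = torch.stack((yy, xx))  # [[r1, r2, ...], [c1, c2, ...]]
print(coords)
# tensor([[1, 0, 0, 0, 1, 1],
#         [0, 1, 2, 0, 2, 1]])
```

The hasattr check keeps the snippet working on older versions; both branches produce the same coordinates for non-negative flat indices.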