`torch.argsort` returns indices within a row or a column, depending on `dim=...`.
How can I instead get 2D indices, e.g.
[[r1, r2, r3, ...], [c1, c2, c3, ...]]
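To illustrate the problem, here is a small sketch. The tensor `x` is an arbitrary example, not from the question. Sorting along `dim=1` or `dim=0` only gives positions within each row or column. Sorting the flattened tensor orders all elements together, but it returns flat (row-major) indices rather than (row, col) pairs.

```python
import torch

# Arbitrary example matrix (no ties, so the argsort result is unique).
x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.5, 4.0, 0.0]])

# dim=1: for each row, the column positions that would sort that row.
per_row = torch.argsort(x, dim=1)
# dim=0: for each column, the row positions that would sort that column.
per_col = torch.argsort(x, dim=0)

# Sorting the flattened tensor orders all elements together, but the result
# indexes into x.flatten(), not into x as (row, col) pairs.
flat = torch.argsort(x.flatten())

print(per_row)
print(per_col)
print(flat)
```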
Thanks. This is what I did:
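```python
import torch

# https://github.com/pytorch/pytorch/issues/35674
def unravel_indices(indices, shape):
    # Walk the dims from last to first: the remainder is the index along the
    # current dim, and the floored quotient carries over to the next one.
    coord = []
    for dim in reversed(shape):
        coord.append(torch.fmod(indices, dim))
        indices = torch.div(indices, dim, rounding_mode='floor')
    # Restore first-to-last order; result has shape (*indices.shape, len(shape)).
    coord = torch.stack(coord[::-1], dim=-1)
    return coord

torch.unravel_indices = unravel_indices  # attach it to the torch namespace
```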
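A possible way to combine this with `argsort` is sketched below. The function is repeated so the snippet runs on its own, and the example tensor `x` is made up. The function stacks coordinates along the last axis, so you get one (row, col) pair per element. Transposing with `.T` gives the `[[r1, r2, ...], [c1, c2, ...]]` layout from the question.

```python
import torch

def unravel_indices(indices, shape):
    # Same logic as above: peel off dims from last to first.
    coord = []
    for dim in reversed(shape):
        coord.append(torch.fmod(indices, dim))
        indices = torch.div(indices, dim, rounding_mode='floor')
    return torch.stack(coord[::-1], dim=-1)

# Arbitrary example matrix.
x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.5, 4.0, 0.0]])

flat = torch.argsort(x.flatten())       # flat indices, ascending by value
pairs = unravel_indices(flat, x.shape)  # shape (6, 2): one (row, col) per element
rows_cols = pairs.T                     # shape (2, 6): [[rows...], [cols...]]

# Sanity check: indexing x with these coordinates yields the sorted values.
sorted_vals = x[rows_cols[0], rows_cols[1]]
print(rows_cols)
print(sorted_vals)
```

Indexing `x` with the two rows of `rows_cols` is a quick way to confirm that the coordinates really point at the sorted elements.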
NumPy has `unravel_index` to do this. PyTorch has no built-in equivalent, but you can do it yourself. For two dimensions it's easy enough:
```python
yy, xx = indices // width, indices % width
```
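Here is a self-contained sketch of the two-dimensional version. It assumes `indices` are flat row-major indices from `argsort` on the flattened tensor, and `x` is a made-up example. It sorts in descending order this time, and `torch.stack` puts the rows and columns into the `[[r1, r2, ...], [c1, c2, ...]]` shape asked for.

```python
import torch

# Arbitrary example matrix.
x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.5, 4.0, 0.0]])
_, width = x.shape

# Flat row-major positions, largest value first.
indices = torch.argsort(x.flatten(), descending=True)

# Row = flat index // width, column = flat index % width.
yy, xx = indices // width, indices % width

rows_cols = torch.stack([yy, xx])  # shape (2, N)
print(rows_cols)
```

This only needs the width because, in row-major order, each row occupies `width` consecutive flat positions. For more than two dimensions, the general loop is the safer choice.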
FWIW, PyTorch has had a feature request and PR(s) for this floating around for several years now.