
Sort a vector in PyTorch


I am running a prediction on an input image with a classifier pre-trained on ImageNet, using PyTorch. What I would like to do is compute the score for each class and return the 10 highest values. My code looks like this:

import torch
import torch.nn.functional as F
from torchvision import models

# imread_img and pre_processing are helper functions not shown here
img = imread_img('image.png')
input = pre_processing(img) # normalize image, transpose and return a tensor

# load model
# model_type = 'vgg19'
model = models.vgg19(pretrained=True)

# run it on a GPU if available:
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")
print('cuda:', cuda, 'device:', device)
model = model.to(device)
input = input.to(device)  # the input must be on the same device as the model

# set model to evaluation
model.eval()

with torch.no_grad():  # no gradients are needed for inference
    out = model(input)
print(out.shape)
out = F.softmax(out, dim=1)

out = torch.sort(out, descending=True)

# top = out[:][0] # that returns only the values and not a tuple

`torch.sort` returns a tuple of two tensors: the sorted values and their indices. How can I keep only the 5 highest values after sorting?
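A minimal sketch of that tuple structure, using a small hand-written tensor (the probabilities are made up and stand in for the softmax output of a single image over four classes). `torch.sort` returns a named tuple with the fields `values` and `indices`, so it can either be unpacked or accessed by field name:

```python
import torch

# Made-up stand-in for the softmax output: batch of 1 image, 4 classes
probs = torch.tensor([[0.10, 0.40, 0.05, 0.45]])

result = torch.sort(probs, dim=1, descending=True)

# The result can be unpacked like a regular tuple ...
values, indices = result

# ... or its fields can be accessed by name
assert torch.equal(values, result.values)
assert torch.equal(indices, result.indices)

print(indices.tolist())  # [[3, 1, 0, 2]]
```

Each entry of `indices` is the class position that the value in the same slot came from, which is what you need to map the top scores back to class labels.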


Solution

  • Unpack the values and indices returned by `torch.sort`, then slice the part you want:

    out_sorted, indices = torch.sort(out, descending=True)
    top_values = out_sorted[:, :5] # Keep the first 5 values from each row
    top_indices = indices[:, :5]   # Keep the corresponding indices
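A self-contained sketch of the same slicing, assuming random logits in place of the real VGG19 output (no model or image is loaded; the batch of 2 and the 1000 classes are illustrative). It also compares the result with `torch.topk`, a built-in PyTorch function that returns the `k` largest entries per row directly, so the full sort is not needed. `k` can be set to 10 to match the original goal:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for model(input): random logits for
# 2 images over 1000 ImageNet classes
logits = torch.randn(2, 1000)
probs = F.softmax(logits, dim=1)

k = 5

# Approach from the answer: sort each row, then keep the first k columns
sorted_probs, sorted_idx = torch.sort(probs, dim=1, descending=True)
top_values = sorted_probs[:, :k]
top_indices = sorted_idx[:, :k]

# Alternative: torch.topk returns the k largest values and their indices,
# already in descending order (sorted=True is the default)
topk_values, topk_indices = torch.topk(probs, k, dim=1)

assert torch.equal(top_values, topk_values)
assert torch.equal(top_indices, topk_indices)
print(top_values.shape)  # torch.Size([2, 5])

# Print the top-k class indices and probabilities for each image
for row in range(probs.size(0)):
    for p, idx in zip(top_values[row].tolist(), top_indices[row].tolist()):
        print(f"image {row}: class {idx}  prob {p:.4f}")
```

Both approaches give the same result here; `torch.topk` is the more direct choice when only the first `k` entries are needed.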