pytorch, backpropagation

Can PyTorch backpropagate through the following operations?


I have a tensor A of shape (N, 3) that comes from the original point cloud data, and a tensor B of shape (N, 1) that holds the scores output by a neural network. First, I use torch.cat to concatenate A and B along dim 1. Then I use torch.argsort to get the indices that order the concatenated tensor by the values in its last column, and torch.gather to sort the concatenated tensor with those indices. Finally, I take the first 10 rows of the sorted tensor and compute the loss from their first three columns, which come from the original point cloud data.

Can PyTorch backpropagate through this process so that the gradients are non-zero? If not, how can I solve the problem? Thanks.
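A minimal reproduction of the steps described above may help. It uses random stand-ins for A and B. The size N = 100, the descending sort (so the first 10 rows have the highest scores), and the squared-sum loss are assumptions for illustration, not details from the question.

```python
import torch

torch.manual_seed(0)
N = 100

# Stand-ins: A plays the point cloud (no gradient), B plays the network scores.
A = torch.randn(N, 3)
B = torch.randn(N, 1, requires_grad=True)

# Concatenate along dim 1 -> shape (N, 4); the last column holds the scores.
C = torch.cat([A, B], dim=1)

# Integer indices that order the rows by score (descending is an assumption).
idx = torch.argsort(C[:, -1], descending=True)

# Reorder whole rows with gather; the index is expanded to C's shape.
C_sorted = torch.gather(C, 0, idx.unsqueeze(1).expand(-1, C.size(1)))

# First 10 rows, point-cloud columns only.
top_points = C_sorted[:10, :3]

# Hypothetical loss on the selected coordinates.
loss = top_points.pow(2).sum()
loss.backward()

# backward() runs without error, but every entry of B.grad is zero.
print(torch.count_nonzero(B.grad))
```

The only path from B to this loss goes through the integer indices, which autograd treats as constants, so the scores receive no learning signal even though `backward()` succeeds.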


Solution

  • After the catted vector is sorted, I will use the first 10 rows of the catted tensor to calculate loss.

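    You need to compute the loss from the neural network's output (the scores), not only from rows selected with the indices that argsort produces. Indices are integers, so the operation that extracts them from the scores is not differentiable and cuts the gradient path. In your pipeline, the first three columns come from the point cloud A, which does not depend on the network; B influences them only through the argsort indices, so the gradient that reaches B is zero.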

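    import torch

    # Stand-in for network scores that require gradients
    out = torch.randn(10, 100).requires_grad_(True)
    # topk returns both the k largest values and their integer indices
    t_v, t_i = torch.topk(out, 10, dim=-1)
    print(t_v.requires_grad)  # prints True: values stay in the autograd graph
    print(t_i.requires_grad)  # prints False: int64 indices carry no gradient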
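The same distinction applies to the argsort route from the question. As a small sketch (the tensor size is arbitrary), torch.sort returns differentiable sorted values alongside the integer indices, while torch.argsort returns only the indices:

```python
import torch

scores = torch.randn(8, requires_grad=True)

# torch.sort returns (values, indices); the values remain in the autograd graph.
values, indices = torch.sort(scores, descending=True)

# torch.argsort returns only the integer indices.
order = torch.argsort(scores, descending=True)

print(values.requires_grad)   # True
print(indices.requires_grad)  # False
print(order.requires_grad)    # False
print(order.dtype)            # torch.int64

# Gradients reach `scores` only through `values`.
values[:3].sum().backward()
print(scores.grad)  # ones at the positions of the 3 largest scores, zeros elsewhere
```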
    

    In short, the loss has to depend on t_v (the differentiable values), not only on data selected through t_i, for gradients to propagate back to the network.
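Applied to the question's setup, one possible approach (an assumption, not the only fix) is to select points with torch.topk on the scores and weight each selected point's loss by its score value, so the loss depends on B through the values rather than through indices alone. The per-point loss `pts.pow(2).sum(dim=1)` and the sizes are hypothetical placeholders.

```python
import torch

torch.manual_seed(0)
N, k = 100, 10

A = torch.randn(N, 3)                        # point cloud, no gradient
B = torch.randn(N, 1, requires_grad=True)    # stand-in for network scores

# Top-k on the scores: values are differentiable, indices are not.
top_vals, top_idx = torch.topk(B.squeeze(1), k)

# Select the corresponding points with the (non-differentiable) indices.
pts = A[top_idx]                             # shape (k, 3)

# Hypothetical per-point loss, weighted by the differentiable score values.
per_point = pts.pow(2).sum(dim=1)            # shape (k,)
loss = (top_vals * per_point).sum()
loss.backward()

# B now receives a non-zero gradient at the k selected rows.
print(torch.count_nonzero(B.grad))
```

Whether weighting by raw scores makes sense depends on the task: raw scores are unbounded, so in practice one might squash them (for example with a sigmoid or softmax) before using them as weights.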

    This post on soft topk may be of interest to you.
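The linked post is not reproduced here. As a rough illustration of the idea only, the sketch below builds a simple sigmoid-based relaxation of a top-k mask (a heuristic chosen for brevity, not necessarily the method in that post). Scores well above a threshold placed between the k-th and (k+1)-th largest score get a weight near 1, scores well below it get a weight near 0, and the temperature `tau` (a hypothetical parameter, as is the helper `soft_topk_mask`) controls how sharp the transition is.

```python
import torch


def soft_topk_mask(scores: torch.Tensor, k: int, tau: float = 0.1) -> torch.Tensor:
    """Hypothetical helper: a smooth approximation of a hard top-k mask."""
    # Threshold halfway between the k-th and (k+1)-th largest scores.
    # It is detached, so gradients flow only through `scores` in the sigmoid.
    top = torch.topk(scores.detach(), k + 1).values
    threshold = (top[k - 1] + top[k]) / 2
    # Near 1 above the threshold, near 0 below; smaller tau -> closer to hard top-k.
    return torch.sigmoid((scores - threshold) / tau)


torch.manual_seed(0)
N, k = 100, 10
A = torch.randn(N, 3)                         # point cloud, no gradient
B = torch.randn(N, 1, requires_grad=True)     # stand-in for network scores

mask = soft_topk_mask(B.squeeze(1), k)        # shape (N,)
per_point = A.pow(2).sum(dim=1)               # hypothetical per-point loss
loss = (mask * per_point).sum() / mask.sum()  # mask-weighted average
loss.backward()

print(mask.sum())                   # soft count of selected points; approaches k as tau shrinks
print(torch.count_nonzero(B.grad))  # non-zero for scores near the threshold
```

A small `tau` makes the mask close to the hard selection but yields very small gradients for scores far from the threshold; a larger `tau` spreads the gradient at the cost of a blurrier selection.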