python, pytorch

PyTorch's .backward() stops Python with no error


When calling .backward() in the code below, Python simply stops without printing any error trace in the CLI. What could be going wrong here?

import torch
import torch.nn.functional as F

output = F.softmax(output, dim=1)
argmax_values = output.max(dim=-1, keepdim=True)[1]
model.zero_grad(set_to_none=True)
print(output, argmax_values)
torch.gather(output, -1, argmax_values).backward(gradient=torch.ones_like(argmax_values))  # Python stops here

Torch version: '1.9.0+cu111'

I tried saving the output of torch.gather in its own variable and then calling .backward() afterwards, to make sure it really is .backward() that fails, and it is.
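
Roughly, that isolation step looked like this (my own reconstruction; the gathered variable name is invented):

gathered = torch.gather(output, -1, argmax_values)  # this line completes fine
gathered.backward(gradient=torch.ones_like(argmax_values))  # the process exits here with no traceback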


Solution

  • Updating to torch version 1.11.0+cu113 solved this issue for me. I'm still left without an explanation as to why it did not work on 1.9.0+cu111, but regardless, I would consider this case closed for me. (A quick way to double-check which build is actually installed is shown below.)
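
As a side note, the installed PyTorch and CUDA build can be checked directly from Python, which is a quick sanity check that the upgrade actually took effect:

import torch

print(torch.__version__)   # e.g. '1.11.0+cu113' after upgrading
print(torch.version.cuda)  # CUDA toolkit the wheel was built against (None on CPU-only builds)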