Even with a very simple example, backward() fails when sparse_grad=True; please see the error below.
Is this error expected, or am I using gather incorrectly?
In [1]: import torch as th
In [2]: x = th.rand((3,3), requires_grad=True)
# sparse_grad=False: backward works as expected
In [3]: th.gather(x @ x, 1, th.LongTensor([[0], [1]]), sparse_grad=False).sum().backward()
# sparse_grad=True: backward fails
In [4]: th.gather(x @ x, 1, th.LongTensor([[0], [1]]), sparse_grad=True).sum().backward()
RuntimeError Traceback (most recent call last)
----> 1 th.gather(x @ x, 1, th.LongTensor([[0], [1]]), sparse_grad=True).sum().backward()
~/miniconda3/lib/python3.9/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
305 create_graph=create_graph,
306 inputs=inputs)
--> 307 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
309 def register_hook(self, hook):
~/miniconda3/lib/python3.9/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
152 retain_graph = create_graph
--> 154 Variable._execution_engine.run_backward(
155 tensors, grad_tensors_, retain_graph, create_graph, inputs,
156 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
RuntimeError: sparse tensors do not have strides
I don't think torch.gather supports sparse tensors:
torch.gather(x, 1, torch.LongTensor([[0], [1]]).to_sparse())
This results in:
NotImplementedError: Could not run 'aten::gather.out' with arguments from the 'SparseCPU' backend.
I think you should open an issue or a feature request on PyTorch's GitHub.
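In the meantime, here is a minimal sketch that may help narrow things down. It rests on an assumption that the thread does not confirm: that the failure comes from the sparse gradient produced by gather being passed back into the backward of x @ x, rather than from gather itself. Under that assumption, calling gather directly on a leaf tensor should work, and the resulting sparse gradient should match the dense one. The names idx, dense_grad and sparse_grad are only illustrative.

```python
import torch as th

th.manual_seed(0)
x = th.rand((3, 3), requires_grad=True)  # leaf tensor, as in the question
idx = th.tensor([[0], [1]])              # int64 index, same values as th.LongTensor([[0], [1]])

# Reference: dense gradient with sparse_grad=False.
out = th.gather(x, 1, idx, sparse_grad=False)
out.backward(th.ones_like(out))
dense_grad = x.grad.clone()
x.grad = None  # reset so the two gradients are not accumulated together

# Assumption being checked: gather applied directly to the leaf x
# (no x @ x in between) can backpropagate with sparse_grad=True.
out = th.gather(x, 1, idx, sparse_grad=True)
out.backward(th.ones_like(out))
sparse_grad = x.grad

print(sparse_grad.is_sparse)                         # is the gradient a sparse tensor?
print(th.equal(dense_grad, sparse_grad.to_dense()))  # do the two gradients agree?
```

If this runs but the x @ x version still fails on your install, that would be worth including in the GitHub issue, since it suggests the problem is in how the preceding op handles a sparse incoming gradient.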