Say there are two PyTorch tensors: `a`, a float32 tensor with shape `[M, N]`, and `b`, an int64 tensor with shape `[K]`. The values in `b` lie within `[0, M-1]`, so the following line gives a new tensor `c` indexed by `b`:
```python
c = a[b]  # [K, N] tensor whose i-th row is a[b[i]], with `IndexBackward`
```
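As a sanity check, here is a minimal, self-contained version of this setup; the concrete values of `M`, `N`, and `K` are chosen arbitrarily for illustration:

```python
import torch

# Minimal reproduction of the setup described above (shapes chosen arbitrarily)
M, N, K = 5, 3, 4
a = torch.randn(M, N, requires_grad=True)          # float32, [M, N]
b = torch.randint(0, M, (K,), dtype=torch.int64)   # int64, [K], values in [0, M-1]

c = a[b]  # advanced indexing: [K, N], i-th row is a[b[i]]
print(c.shape)    # torch.Size([4, 3])
print(c.grad_fn)  # an IndexBackward node (exact name varies across versions)
```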
However, in a project of mine, this line consistently triggers the following error (detected with `torch.autograd.detect_anomaly()`):
```
with torch.autograd.detect_anomaly():
[W python_anomaly_mode.cpp:104] Warning: Error detected in IndexBackward. Traceback of forward call that caused the error:
  ...
  File "/home/user/project/model/network.py", line 60, in index_points
    c = a[b]
 (function _print_stack)
Traceback (most recent call last):
  File "main.py", line 589, in <module>
    main()
  File "main.py", line 439, in main
    train_stats = train(
  File "/home/user/project/train_eval.py", line 866, in train
    total_loss.backward()
  File "/home/user/.local/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward
    Variable._execution_engine.run_backward(
RuntimeError: merge_sort: failed to synchronize: cudaErrorIllegalAddress: an illegal memory access was encountered
```
Note that the line `c = a[b]` above is not the only place this error appears; it is just one among many lines that use square-bracket indexing.
However, the problem magically goes away when I change the indexing style from

```python
c = a[b]
```

to

```python
c = a.index_select(0, b)
```
I don't understand why indexing with square brackets leads to an illegal memory access, but this gives me enough reason to believe that square-bracket indexing and `index_select` are implemented differently. Understanding that difference could be the key to explaining this. Also, since the project is rather large and not public, I can't share the exact code here. You can treat the above as background and focus on how square-bracket indexing and `index_select` differ. Thanks!
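For what it's worth, one common trigger for `cudaErrorIllegalAddress` during indexing is an out-of-range index: CUDA kernels do not always bounds-check, and because kernel launches are asynchronous, the error can surface at a later synchronization point with a misleading traceback. Below is a hypothetical defensive wrapper (the name `safe_index` is my invention, not part of PyTorch) that validates indices on the host before gathering rows:

```python
import torch

def safe_index(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: validate `b` before gathering rows of `a`.

    An out-of-range index on CUDA may not fail at the indexing line itself;
    it can surface later as cudaErrorIllegalAddress at a sync point.
    """
    if b.numel() > 0:
        lo, hi = int(b.min()), int(b.max())
        if lo < 0 or hi >= a.shape[0]:
            raise IndexError(
                f"indices span [{lo}, {hi}] but dim 0 has size {a.shape[0]}"
            )
    return a.index_select(0, b)

a = torch.randn(5, 3)
b = torch.tensor([0, 4, 2])
print(torch.equal(safe_index(a, b), a[b]))  # True: same result as a[b]
```

Running the code on CPU (or with `CUDA_LAUNCH_BLOCKING=1`) is another standard way to make such errors fail at the real offending line.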
Additional information:
- `torch.index_select` returns a new tensor which copies the indexed fields into a new memory location (docs).
- `torch.Tensor.select` or slicing returns a view of the original tensor (docs).
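A small sketch of that copy-vs-view distinction, with illustrative values only:

```python
import torch

a = torch.arange(6.0).reshape(3, 2)

sel = a.index_select(0, torch.tensor([0, 2]))  # copies rows 0 and 2
view = a[0:2]                                  # basic slicing: a view of `a`

a[0, 0] = 100.0  # mutate the original tensor

print(sel[0, 0].item())   # 0.0   -- the copy is unaffected
print(view[0, 0].item())  # 100.0 -- the view reflects the change
print(view.data_ptr() == a.data_ptr())  # True: the view shares storage
```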
Without seeing more of your code, it's hard to say why this particular difference in functionality might cause the above error.