Is it possible to subset a PyTorch tensor using a tensor index without causing a backward second-time error?


I am trying to write a PyTorch program that computes embeddings, where different subsets of the embeddings are used during training. I want to subset the embedding tensor using a tensor index, but when I do this I get the familiar "Trying to backward through the graph a second time" error. However, performing the exact same operation with a slice instead of an index runs fine. In my actual code, I have confirmed that the slice version works as intended and trains properly, but the new features I need cannot be implemented with slices alone.

I am running Python 3.10.9 and PyTorch 2.0.0 with CUDA 11.8 on Windows.

Here is a minimal example that reproduces the error I am running into:

import torch
device = 'cpu'    # same error occurs when using device='cuda:0'
embeddings = torch.randn([1024, 128], device=device, dtype=torch.float32, requires_grad=True)
target = torch.rand([1024, 128], device=device)
optimizer = torch.optim.Adadelta([embeddings], lr=1.0)

# alternate approach using nn.Embedding doesn't work either
embeddings_2 = torch.nn.Embedding(1024, 128, device=device)

# OPTION 1: This works
# cur_embeddings = embeddings[:512, :]

# OPTION 2: This is what I want to do, but it doesn't work
cur_embeddings = embeddings[torch.arange(512, device=device), :]

# OPTION 3: The following option does not work either
# cur_embeddings = embeddings_2(torch.arange(512, device=device))

cur_target = target[:512, :]
for idx in range(2):
    optimizer.zero_grad()
    loss = torch.nn.MSELoss()(cur_embeddings, cur_target)
    loss.backward()
    optimizer.step()

Of the three options, option 1 works, but option 2 raises the error even though it selects exactly the same rows. How can I make option 2 behave like option 1 while keeping the freedom to select arbitrary rows?

Option 3 shows that using nn.Embedding (which I would rather avoid, but would use if it worked) produces the same error.
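Options 1 and 2 return equal values, but autograd treats them as different operations. The sketch below, assuming PyTorch 2.x, checks this on the question's tensor shapes. count_saved_tensors is a hypothetical helper built on torch.autograd.graph.saved_tensors_hooks, which intercepts every tensor an operation saves for its backward pass. The slice is expected to save no tensors and to share storage with embeddings (a view). The tensor index is expected to save the index tensor and to allocate a new copy.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def count_saved_tensors(fn):
    """Hypothetical helper: run fn() and count the tensors autograd saves for backward."""
    saved = []

    def pack(t):
        saved.append(t)  # record each tensor the op stores for its backward pass
        return t

    def unpack(t):
        return t

    with saved_tensors_hooks(pack, unpack):
        out = fn()
    return out, len(saved)


embeddings = torch.randn(1024, 128, requires_grad=True)
idx = torch.arange(512)

sliced, n_slice = count_saved_tensors(lambda: embeddings[:512, :])
indexed, n_index = count_saved_tensors(lambda: embeddings[idx, :])

print(torch.equal(sliced, indexed))                  # True: same selected values
print(sliced.data_ptr() == embeddings.data_ptr())    # True: the slice is a view into embeddings
print(indexed.data_ptr() == embeddings.data_ptr())   # False: advanced indexing allocates a copy
print(n_slice, n_index)                               # tensors saved for backward by each op
```

The saved index tensor is what the first backward releases, which is why only the indexed version breaks when its graph is reused.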


Solution

  • Solution: Put this line inside your training loop:

    cur_embeddings = embeddings[torch.arange(512, device=device), :]
    

    Explanation: Your computation graph looks like this:

    embeddings ---> cur_embeddings ---> loss
                |                   |
        torch.arange(512)       cur_target
    

    To backpropagate to embeddings, PyTorch must remember how you indexed embeddings to get cur_embeddings; that is, it saves torch.arange(512) in the graph for the backward pass. Because the indexing operation sits outside the loop, that part of the graph is built only once, and by default (retain_graph=False) its saved tensors, including the index, are freed after the first backward. By the second backward, PyTorch has already "forgotten" how the indexing was done, so it raises the error.

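    Your option 1 works because a slice such as embeddings[:512] is recorded with plain integer bounds rather than a saved tensor, so the first backward has nothing to free. Option 3 fails for the same reason as option 2: nn.Embedding also saves its index tensor for backward, so call embeddings_2(torch.arange(512, device=device)) inside the loop as well.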
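Putting this together, here is a sketch of the question's training loop with the indexing moved inside it, assuming PyTorch 2.x. To illustrate selecting arbitrary rows, it uses a random subset from torch.randperm instead of torch.arange(512). The names rows, emb_module, and index_outside_loop_fails are illustrative; the last one is a hypothetical helper that only reproduces the original failure on a small tensor.

```python
import torch

torch.manual_seed(0)
device = "cpu"  # "cuda:0" should behave the same way

embeddings = torch.randn(1024, 128, device=device, requires_grad=True)
target = torch.rand(1024, 128, device=device)
optimizer = torch.optim.Adadelta([embeddings], lr=1.0)
loss_fn = torch.nn.MSELoss()

# Any 1-D LongTensor of row indices works; a random subset is used for illustration.
rows = torch.randperm(1024, device=device)[:512]
cur_target = target[rows]  # target needs no grad, so indexing it once is fine

losses = []
for step in range(5):
    optimizer.zero_grad()
    # Index inside the loop: each iteration builds a new graph node that saves
    # `rows`, and it reads the current (already updated) embedding values.
    cur_embeddings = embeddings[rows]
    loss = loss_fn(cur_embeddings, cur_target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Option 3 follows the same rule: call the nn.Embedding module inside the loop.
emb_module = torch.nn.Embedding(1024, 128, device=device)
emb_optimizer = torch.optim.Adadelta(emb_module.parameters(), lr=1.0)
emb_losses = []
for step in range(5):
    emb_optimizer.zero_grad()
    loss = loss_fn(emb_module(rows), cur_target)
    loss.backward()
    emb_optimizer.step()
    emb_losses.append(loss.item())


def index_outside_loop_fails() -> bool:
    """Hypothetical helper: reproduce the original bug (indexing hoisted out of the loop)."""
    e = torch.randn(16, 4, requires_grad=True)
    opt = torch.optim.Adadelta([e], lr=1.0)
    cur = e[torch.arange(8)]  # indexing node built once, outside the loop
    try:
        for _ in range(2):
            opt.zero_grad()
            cur.pow(2).mean().backward()  # second pass hits the freed saved index
            opt.step()
    except RuntimeError:
        return True
    return False
```

Passing retain_graph=True to backward() would also silence the error, but it is not a substitute here: advanced indexing returns a copy, so a cur_embeddings computed once outside the loop would keep the values from before the first optimizer.step(), and the loss would never reflect the updates.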