Tags: python, pytorch, lstm

Truncated BPTT PyTorch implementation question


I'm trying to implement TBPTT (truncated backpropagation through time) in PyTorch.

I found the implementation below on a forum, and I understand the logic behind the code, but I keep getting an "inplace operation" error.

import time

import torch
import torch.nn as nn


class TBPTT():
    def __init__(self, one_step_module, loss_module, k1, k2, optimizer):
        self.one_step_module = one_step_module
        self.loss_module = loss_module
        self.k1 = k1
        self.k2 = k2
        self.retain_graph = k1 < k2
        # You can also remove all the optimizer code here, and the
        # train function will just accumulate all the gradients in
        # one_step_module parameters
        self.optimizer = optimizer

    def train(self, input_sequence, init_state):
        states = [(None, init_state)]
        for j, (inp, target) in enumerate(input_sequence):

            # Detach the previous state so this chunk's graph starts here;
            # requires_grad lets us read its .grad later to chain the backward manually.
            state = states[-1][1].detach()
            state.requires_grad = True
            output, new_state = self.one_step_module(inp, state)
            states.append((state, new_state))

            while len(states) > self.k2:
                # Delete stuff that is too old
                del states[0]

            if (j+1)%self.k1 == 0:
                loss = self.loss_module(output, target)

                # NOTE: `optimizer` here (and in optimizer.step() below) resolves to the
                # module-level optimizer defined further down; it is the same object
                # that is passed in as self.optimizer.
                optimizer.zero_grad()
                # backprop last module (keep graph only if they ever overlap)
                start = time.time()
                loss.backward(retain_graph=self.retain_graph)
                for i in range(self.k2-1):
                    # if we get all the way back to the "init_state", stop
                    if states[-i-2][0] is None:
                        break
                    curr_grad = states[-i-1][0].grad
                    states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)
                print("bw: {}".format(time.time()-start))
                optimizer.step()



seq_len = 20
layer_size = 50

idx = 0

class MyMod(nn.Module):
    def __init__(self):
        super(MyMod, self).__init__()
        self.lin = nn.Linear(2*layer_size, 2*layer_size)

    def forward(self, inp, state):
        global idx
        full_out = self.lin(torch.cat([inp, state], 1))
        # out, new_state = full_out.chunk(2, dim=1)
        out = full_out.narrow(1, 0, layer_size)
        new_state = full_out.narrow(1, layer_size, layer_size)
        def get_pr(idx_val):
            def pr(*args):
                print("doing backward {}".format(idx_val))
            return pr
        new_state.register_hook(get_pr(idx))
        out.register_hook(get_pr(idx))
        print("doing fw {}".format(idx))
        idx += 1
        return out, new_state


one_step_module = MyMod()
loss_module = nn.MSELoss()
input_sequence = [(torch.rand(200, layer_size), torch.rand(200, layer_size))] * seq_len

optimizer = torch.optim.SGD(one_step_module.parameters(), lr=1e-3)

runner = TBPTT(one_step_module, loss_module, 5, 7, optimizer)

runner.train(input_sequence, torch.zeros(200, layer_size))
print("done")

Here is the weird thing. The first time I ran the code, I kept getting a different error, and after a thorough inspection I found that some of the variables, such as "one_step_module" and "input_sequence", were shadowing variables in an outer scope. After renaming those variables the code ran just fine. Then, when I revised the code a bit further for my own project, I started getting the "inplace operation" error. To see what went wrong, I reverted everything to the original code above, but I kept getting the error. I even opened a new file and copy-pasted the implementation from scratch, and I still can't get the code to run. This is driving me CRAZY.
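
To show the kind of scoping trap I mean (this is a toy example, not my project code): the train() method above calls a bare optimizer.zero_grad() / optimizer.step(), so it only works because a module-level variable with that exact name happens to exist at call time.

# Toy example of outer-scope name resolution (not my project code).
import torch
import torch.nn as nn

class Trainer:
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):
        # Bare name: Python looks this up in the module's global scope,
        # not in self; rename the global below and this raises NameError.
        optimizer.zero_grad()

model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
Trainer(optimizer).step()   # works only because the global `optimizer` exists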

Here's the "inplace operation" error I started getting from the implementation above.

C:\Users\bboyj\anaconda3\envs\jinkyu\python.exe C:/Users/bboyj/PycharmProjects/pythonProject/test1.py
doing fw 0
doing fw 1
doing fw 2
doing fw 3
doing fw 4
doing backward 4
doing backward 3
doing backward 2
doing backward 1
doing backward 0
bw: 0.17385029792785645
doing fw 5
doing fw 6
doing fw 7
doing fw 8
doing fw 9
doing backward 9
doing backward 8
doing backward 7
doing backward 6
doing backward 5
doing backward 4
C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\autograd\__init__.py:130: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  ..\c10\cuda\CUDAFunctions.cpp:100.)
  Variable._execution_engine.run_backward(
Traceback (most recent call last):
  File "C:/Users/bboyj/PycharmProjects/pythonProject/test1.py", line 100, in <module>
    runner.train(input_sequence, torch.zeros(200, layer_size))
  File "C:/Users/bboyj/PycharmProjects/pythonProject/test1.py", line 59, in train
    states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)
  File "C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\autograd\__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [100, 100]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Process finished with exit code 1

Just in case you want to see the specific operation triggering the error, here's the error log with torch anomaly detection enabled.

C:\Users\bboyj\anaconda3\envs\jinkyu\python.exe C:/Users/bboyj/PycharmProjects/pythonProject/test1.py
doing fw 0
doing fw 1
doing fw 2
doing fw 3
doing fw 4
doing backward 4
doing backward 3
doing backward 2
doing backward 1
doing backward 0
bw: 0.17083358764648438
doing fw 5
doing fw 6
doing fw 7
doing fw 8
doing fw 9
doing backward 9
doing backward 8
doing backward 7
doing backward 6
doing backward 5
doing backward 4
C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\autograd\__init__.py:130: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  ..\c10\cuda\CUDAFunctions.cpp:100.)
  Variable._execution_engine.run_backward(
C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\autograd\__init__.py:130: UserWarning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "C:/Users/bboyj/PycharmProjects/pythonProject/test1.py", line 101, in <module>
    runner.train(input_sequence, torch.zeros(200, layer_size))
  File "C:/Users/bboyj/PycharmProjects/pythonProject/test1.py", line 41, in train
    output, new_state = self.one_step_module(inp, state)
  File "C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:/Users/bboyj/PycharmProjects/pythonProject/test1.py", line 78, in forward
    full_out = self.lin(torch.cat([inp111, state111], 1))
  File "C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\nn\functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
 (Triggered internally at  ..\torch\csrc\autograd\python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(
Traceback (most recent call last):
  File "C:/Users/bboyj/PycharmProjects/pythonProject/test1.py", line 101, in <module>
    runner.train(input_sequence, torch.zeros(200, layer_size))
  File "C:/Users/bboyj/PycharmProjects/pythonProject/test1.py", line 60, in train
    states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)
  File "C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\Users\bboyj\anaconda3\envs\jinkyu\lib\site-packages\torch\autograd\__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [100, 100]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The main problem is that the first chunk is fine, because the loss is computed only from the new hidden state and the detached state with requires_grad = True, but on the second chunk, when it tries to backward through the previous hidden states that have already been backwarded through, it raises the error. So in this case, after forwarding and backwarding on t = 0,1,2,3,4 and then forwarding on t = 5,6,7,8,9, when it tries to backward on t = 9,8,7,6,5,4,3 (because k2 is 7), backward works fine for t = 9,8,7,6,5 but fails at t = 4. Can anyone please shed some light on this?
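
To make the failure mode concrete, here is a tiny self-contained snippet (a toy example, not the code above) that trips the same check: optimizer.step() updates the Linear weight in place, and a later backward through a graph that was built before that step fails autograd's version check.

# Toy reproduction: an in-place parameter update between building a graph
# and backwarding through it triggers the "inplace operation" RuntimeError.
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

x = torch.randn(1, 4, requires_grad=True)
out_old = lin(x)              # this graph saves lin.weight (via weight.t())

lin(x).sum().backward()       # an independent graph; fills the parameter grads
opt.step()                    # in-place update of lin.weight -> version bump

try:
    out_old.sum().backward()  # graph built before the step: version check fails
except RuntimeError as e:
    print(e)                  # "... modified by an inplace operation ..."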


Solution

  • After a careful inspection of the code, I've tracked the bug down. The problem was that, after backwarding through the stored hidden states, the optimizer was stepping on (updating in place) parameters that the graphs of already-computed hidden states still needed. I moved the optimizer out of the scope of the for loop and everything works fine!

    I'm leaving this answer here for anyone else trying to implement truncated BPTT; a sketch of how I read the fix is below.
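
For completeness, here is a minimal sketch of how I read that fix (my interpretation, not necessarily the exact revision): zero the gradients once, let them accumulate over the whole sequence (as the comment in the original class already suggests), and call optimizer.step() only after the loop, so no parameter is updated in place while a graph that still has to be backwarded through is alive. Only train() changes; the rest of the class stays the same.

def train(self, input_sequence, init_state):
    self.optimizer.zero_grad()
    states = [(None, init_state)]
    for j, (inp, target) in enumerate(input_sequence):
        state = states[-1][1].detach()
        state.requires_grad = True
        output, new_state = self.one_step_module(inp, state)
        states.append((state, new_state))

        while len(states) > self.k2:
            del states[0]  # drop states that are too old to reach

        if (j + 1) % self.k1 == 0:
            loss = self.loss_module(output, target)
            loss.backward(retain_graph=self.retain_graph)
            for i in range(self.k2 - 1):
                # stop once we get all the way back to the init_state
                if states[-i-2][0] is None:
                    break
                curr_grad = states[-i-1][0].grad
                states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)

    # Parameters are only updated here, after all backward passes are done,
    # so the graphs built earlier in the sequence are never invalidated.
    self.optimizer.step()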