Tags: python, deep-learning, pytorch, lstm, pytorch-lightning

How to run an LSTM on a very long sequence using truncated backpropagation in PyTorch (Lightning)?


I have a very long time series that I want to feed into an LSTM for per-frame classification.

My data is labeled per frame, and I know that some rare events heavily influence the classification of every frame from the moment they occur.

Thus, I have to feed the entire sequence to get meaningful predictions.

It is known that simply feeding very long sequences into an LSTM is sub-optimal, since the gradients vanish or explode, just as they do in ordinary RNNs.


I wanted to use a simple technique: cut the sequence into shorter (say, 100-frame) sequences, run the LSTM on each, and pass the final LSTM hidden and cell states as the initial hidden and cell states of the next forward pass.
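As a minimal sketch of the chunking idea (forward pass only, no training), the snippet below splits a sequence along the time axis in plain PyTorch and carries the (h, c) state from one chunk to the next. The sizes and the random input are hypothetical placeholders, not values from the model in this question. With the state carried over, the chunked outputs are expected to match a single full-sequence pass up to floating-point tolerance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
n_features, hidden_size, seq_len, chunk_len = 4, 8, 350, 100

lstm = nn.LSTM(input_size=n_features, hidden_size=hidden_size, batch_first=True)
x = torch.randn(1, seq_len, n_features)  # (batch, time, features)

with torch.no_grad():
    # Reference: the whole sequence in one forward pass.
    full_out, _ = lstm(x)

    # Chunked: feed consecutive slices and carry the state forward.
    state = None  # None makes nn.LSTM start from zero states
    chunk_outs = []
    for chunk in torch.split(x, chunk_len, dim=1):
        out, state = lstm(chunk, state)
        chunk_outs.append(out)
    chunked_out = torch.cat(chunk_outs, dim=1)

# Compare the two forward passes within a small tolerance.
same_forward = torch.allclose(full_out, chunked_out, atol=1e-5)
print(same_forward)
```

This only shows that carrying the state preserves the forward computation; how gradients are cut between chunks during training is a separate matter.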

Here is an example I found of someone who did just that. There it is called "Truncated Back propagation through time". I was not able to make the same work for me.


My attempt in PyTorch Lightning (stripped of irrelevant parts):

def __init__(self, config, n_classes, datamodule):
    ...
    self._criterion = nn.CrossEntropyLoss(
        reduction='mean',
    )

    num_layers = 1
    hidden_size = 50
    batch_size=1

    self._lstm1 = nn.LSTM(input_size=len(self._in_features), hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
    self._log_probs = nn.Linear(hidden_size, self._n_predicted_classes)
    self._last_h_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)
    self._last_c_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)

def training_step(self, batch, batch_index):
    orig_batch, label_batch = batch
    n_labels_in_batch = np.prod(label_batch.shape)
    lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch, (self._last_h_n, self._last_c_n))
    log_probs = self._log_probs(lstm_out)
    loss = self._criterion(log_probs.view(n_labels_in_batch, -1), label_batch.view(n_labels_in_batch))

    return loss

Running this code gives the following error:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

The same happens if I add

def on_after_backward(self) -> None:
    self._last_h_n.detach()
    self._last_c_n.detach()

The error does not happen if I use

lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch,)

But obviously this is useless, as the output from the current frame-batch is not forwarded to the next one.


What is causing this error? I thought detaching the output h_n and c_n should be enough.

How do I pass the output state of one frame-batch to the next and have torch backpropagate through each frame-batch separately?


Solution

  • Apparently, I missed the trailing _ for detach():

    Using

    def on_after_backward(self) -> None:
        self._last_h_n.detach_()
        self._last_c_n.detach_()
    

    works.


    The problem was that self._last_h_n.detach() is not an in-place operation: it returns a new tensor that is detached from the graph and leaves self._last_h_n itself unchanged. Because I discarded the return value, self._last_h_n still pointed into the graph of the previous frame-batch, which backprop had already gone through and freed. The reference answer solved that by reassigning: H = H.detach().

    Cleaner (and probably faster) is self._last_h_n.detach_(), which does the operation in place.
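For readers who want to see the whole pattern outside Lightning, here is a minimal plain-PyTorch sketch of truncated backpropagation through time under stated assumptions: the sizes, the random data, the SGD optimizer and the learning rate are all hypothetical and not taken from the model above. It uses the reassignment form (h = h.detach()) mentioned above; the in-place detach_() from the solution would go in the same spot. It also records that a bare h.detach() call leaves h attached to the old graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and random data, for illustration only.
n_features, hidden_size, n_classes = 4, 8, 3
seq_len, chunk_len = 300, 100

lstm = nn.LSTM(n_features, hidden_size, batch_first=True)
head = nn.Linear(hidden_size, n_classes)
criterion = nn.CrossEntropyLoss()
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)

x = torch.randn(1, seq_len, n_features)        # (batch, time, features)
y = torch.randint(0, n_classes, (1, seq_len))  # one label per frame

# Initial states: (num_layers, batch, hidden_size)
h = torch.zeros(1, 1, hidden_size)
c = torch.zeros(1, 1, hidden_size)

losses = []
x_chunks = torch.split(x, chunk_len, dim=1)
y_chunks = torch.split(y, chunk_len, dim=1)
for x_chunk, y_chunk in zip(x_chunks, y_chunks):
    optimizer.zero_grad()
    out, (h, c) = lstm(x_chunk, (h, c))
    logits = head(out)
    loss = criterion(logits.reshape(-1, n_classes), y_chunk.reshape(-1))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # Discarding the return value changes nothing: h still has a grad_fn.
    _ = h.detach()
    still_attached = h.grad_fn is not None

    # Rebinding the names cuts the graph, so the next backward() stops here.
    h, c = h.detach(), c.detach()

state_is_detached = h.grad_fn is None and not h.requires_grad
print(len(losses), still_attached, state_is_detached)
```

Removing the reassignment line should reproduce the "backward through the graph a second time" error on the second chunk, because the carried state would still reference the freed graph of the first chunk.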