
In PyTorch, how do I update a neural network via the average gradient from a list of losses?


I have a toy reinforcement learning project based on the REINFORCE algorithm (here's PyTorch's implementation) that I would like to add batch updates to. In RL, the "target" can only be created after a "prediction" has been made, so standard batching techniques do not apply. As such, I accrue losses for each episode and append them to a list l_losses where each item is a zero-dimensional tensor. I hold off on calling .backward() or optimizer.step() until a certain number of episodes have passed in order to create a sort of pseudo batch.

Given this list of losses, how do I have PyTorch update the network based on their average gradient? Or would updating based on the average gradient be the same as updating on the average loss (I seem to have read otherwise elsewhere)?

My current method is to create a new tensor t_loss from torch.stack(l_losses), then run t_loss = t_loss.mean(), t_loss.backward(), optimizer.step(), and zero the gradient, but I'm unsure whether this does what I intend. It's also unclear to me whether I should have been running .backward() on each individual loss instead of collecting them in a list (but holding off on the .step() part until the end).
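For reference, here is a minimal sketch of the pseudo-batch update described above. It assumes a hypothetical toy policy (`nn.Linear` mapping 4-dimensional states to 2 action logits), random placeholder states and returns in place of a real environment, and a hypothetical `episodes_per_batch` setting. Only the stack/mean/backward/step/zero-grad sequence comes from the question.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical toy policy: 4-dimensional state -> logits over 2 actions.
policy = nn.Linear(4, 2)
optimizer = torch.optim.SGD(policy.parameters(), lr=0.01)

episodes_per_batch = 4  # hypothetical pseudo-batch size
steps_per_episode = 5
l_losses = []  # one zero-dimensional loss tensor per episode

params_before = [p.detach().clone() for p in policy.parameters()]

for episode in range(episodes_per_batch):
    # Placeholder rollout: random states and returns stand in for an environment.
    states = torch.randn(steps_per_episode, 4)
    returns = torch.randn(steps_per_episode)
    dist = Categorical(logits=policy(states))
    actions = dist.sample()
    # REINFORCE-style loss for this episode: -sum_t log pi(a_t | s_t) * G_t
    loss = -(dist.log_prob(actions) * returns).sum()
    l_losses.append(loss)

# After the pseudo-batch: average the episode losses and take one step.
t_loss = torch.stack(l_losses).mean()
t_loss.backward()
optimizer.step()
optimizer.zero_grad()

params_changed = any(
    not torch.equal(before, after)
    for before, after in zip(params_before, policy.parameters())
)
print(t_loss.dim(), params_changed)
```

Note that every loss in `l_losses` keeps its computation graph alive until the single `backward()` call.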


Solution

  • The gradient is a linear operation, so the gradient of the average is the same as the average of the gradients.
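Written out, with θ standing for the network parameters and L_1, ..., L_N for the per-episode losses (notation introduced here for illustration), the linearity argument is:

```latex
\nabla_\theta \left( \frac{1}{N} \sum_{i=1}^{N} L_i(\theta) \right)
  = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta L_i(\theta)
```

So a single backward pass on the mean loss gives the average of the per-loss gradients, up to floating-point rounding.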

    Take some example data:

    import torch
    a = torch.randn(1, 4, requires_grad=True)
    b = torch.randn(5, 4)
    

    You could store all the losses and compute their mean, as you are doing:

    a.grad = None
    x = (a * b).mean(dim=1)  # one loss per row of b
    x.mean().backward() # gradient of the mean
    print(a.grad)
    

    Or you could call backward() on each loss as soon as it is computed; each call adds that loss's contribution to the gradient, which accumulates in a.grad.

    a.grad = None
    for bi in b:
        (a * bi / len(b)).mean().backward()
    print(a.grad)
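To check the equivalence programmatically instead of comparing printed values, here is a self-contained sketch. It uses `torch.manual_seed` and `torch.allclose`, which are not part of the original answer, to compute both gradients and compare them:

```python
import torch

torch.manual_seed(0)
a = torch.randn(1, 4, requires_grad=True)
b = torch.randn(5, 4)

# Gradient of the mean loss (a single backward pass).
a.grad = None
(a * b).mean(dim=1).mean().backward()
grad_of_mean = a.grad.clone()

# Average of the per-loss gradients (one backward pass per loss; .grad accumulates).
a.grad = None
for bi in b:
    (a * bi / len(b)).mean().backward()
mean_of_grads = a.grad.clone()

print(torch.allclose(grad_of_mean, mean_of_grads))
```

Here both results equal the column means of b divided by 4, because each loss is itself a mean over 4 elements.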
    

    Performance

    I don't know the internal details of PyTorch's backward implementation, but I can tell you that:

    (1) the graph is freed by default after the backward pass, unless you pass retain_graph=True or create_graph=True to backward();

    (2) gradients are only kept for leaf tensors, unless you call .retain_grad() on a non-leaf tensor;

    (3) if you evaluate a model twice on different inputs, you can run the backward pass on each output individually, which means they have separate graphs. This can be verified with the following code:

    a.grad = None
    # compute all the variables in advance
    r = [(a * bi / len(b)).mean() for bi in b]
    for ri in r:
        # This uses the graph of r[i], while the graph of r[i-1]
        # was already freed; so the graph of r[i] is independent
        # of the graph of r[i-1], hence they occupy separate memory.
        ri.backward()  # this frees the graph of ri
    print(a.grad)
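Points (1) and (2) can also be observed directly. The following sketch uses illustrative tensors (not from the question) to show that a second backward() through the same graph fails unless retain_graph=True was passed the first time, and that a non-leaf tensor's .grad is only populated after retain_grad():

```python
import torch

a = torch.randn(1, 4, requires_grad=True)

# (1) The graph is freed after backward() unless retain_graph=True.
y = (a ** 2).sum()
y.backward()
try:
    y.backward()  # the tensors saved by this graph were already freed
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

a.grad = None
y = (a ** 2).sum()
y.backward(retain_graph=True)
y.backward()  # allowed: the first call kept the graph
# Two backward passes accumulate: a.grad is 2 * (2 * a).
retained_ok = torch.allclose(a.grad, 4 * a.detach())

# (2) Non-leaf tensors only keep .grad if retain_grad() was called.
g = a * 2  # non-leaf, no retain_grad()
g.sum().backward()
non_leaf_grad_missing = g.grad is None  # PyTorch warns and returns None here

h = a * 3  # non-leaf, with retain_grad()
h.retain_grad()
h.sum().backward()
print(second_backward_failed, retained_ok, non_leaf_grad_missing, h.grad)
```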
    

    So if you call backward() after each episode, the gradients accumulate in the leaf tensors' .grad, which is all the information you need for the next optimization step. You can then discard that loss and its graph, freeing resources for further computations. I would expect lower memory usage, and potentially even faster execution if the allocator can reuse the just-freed memory for the next allocation.
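Applied to the question's setup, a minimal sketch of this per-episode variant could look like the following. It assumes a hypothetical toy policy and placeholder rollouts (not the asker's actual environment). It calls backward() on loss / episodes_per_batch right after each episode, so that episode's graph can be freed, and compares the accumulated gradient with the one obtained by stacking and averaging all losses:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

policy = nn.Linear(4, 2)  # hypothetical toy policy: 4-dim state -> 2 action logits
episodes_per_batch = 4    # hypothetical pseudo-batch size

# Placeholder rollouts: fixed random states, actions and returns per episode.
episodes = [
    (torch.randn(5, 4), torch.randint(0, 2, (5,)), torch.randn(5))
    for _ in range(episodes_per_batch)
]

def episode_loss(states, actions, returns):
    # REINFORCE-style loss: -sum_t log pi(a_t | s_t) * G_t
    dist = Categorical(logits=policy(states))
    return -(dist.log_prob(actions) * returns).sum()

# Variant 1: keep every loss (and its graph) alive, then backward on the mean.
policy.zero_grad()
torch.stack([episode_loss(*ep) for ep in episodes]).mean().backward()
grads_mean_loss = [p.grad.clone() for p in policy.parameters()]

# Variant 2: backward after each episode; .grad accumulates and each graph is freed.
policy.zero_grad()
for ep in episodes:
    (episode_loss(*ep) / episodes_per_batch).backward()
grads_accumulated = [p.grad.clone() for p in policy.parameters()]

# In training, optimizer.step() and optimizer.zero_grad() would follow here.
same = all(
    torch.allclose(g1, g2) for g1, g2 in zip(grads_mean_loss, grads_accumulated)
)
print(same)
```

Dividing each loss by episodes_per_batch is what makes the accumulated sum equal to the gradient of the mean; without it, the step would be scaled by the number of episodes.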