Tags: python, pytorch, computation-graph

Find PyTorch model parameters that don't contribute to loss


In PyTorch (v1.10) DistributedDataParallel, unused parameters in a model that don't contribute to the final loss can raise a RuntimeError (as mentioned in this other question and in this PyTorch forums thread):

"RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by making sure all forward function outputs participate in calculating loss."

Although it's possible to inspect which parameters are affected at error time (as mentioned above, or by setting the env var TORCH_DISTRIBUTED_DEBUG="INFO"), it seems like there should be a way to statically inspect a model and locate (and presumably prune, or disable gradients on) parameters that aren't contributing to the current loss objective.
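For reference, the two runtime workarounds mentioned above look roughly like this; `model` and `local_rank` are placeholders for your own module and device setup:

    # Workaround 1: let DDP detect unused parameters at runtime (adds some overhead)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], find_unused_parameters=True
    )

    # Workaround 2: make the error message name the offending parameters
    # (set in the environment before launching the processes):
    #   export TORCH_DISTRIBUTED_DEBUG=INFO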

So given a torch.nn.Module-based model whose forward() function returns some loss tensor (maybe alongside others), how can we programmatically, before starting to train, find all parameters (including those of nested modules) that aren't contributing to the loss?


Solution

  • By default, PyTorch tensors that are the result of some computation record their history, that is, their ancestor operations. This is needed for the backward pass to compute the gradient.

    We can make use of this to find all tensors that contribute to a given output tensor by simply walking through that whole history.
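    Concretely, every non-leaf tensor has a grad_fn node, and grad_fn.next_functions points at the nodes of its inputs; leaf tensors such as parameters appear as AccumulateGrad nodes that hold the tensor in their .variable attribute. A minimal illustration (the exact node names may differ between PyTorch versions):

    import torch
    
    w = torch.tensor([2.0], requires_grad=True)
    z = w * 3
    print(z.grad_fn)                 # e.g. <MulBackward0 object at 0x...>
    print(z.grad_fn.next_functions)  # ((<AccumulateGrad object at 0x...>, 0), (None, 0))
    print(z.grad_fn.next_functions[0][0].variable is w)  # True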

    Note that this only works for a static network that always uses the same architecture. As soon as you have conditionals that, e.g., depend on some intermediate value, this won't work, and I claim that in that case it is impossible to determine in advance which tensors are involved (it's similar to the halting problem).

    import torch
    import torch.nn as nn
    # Example of a simple network
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.x = nn.Parameter(torch.tensor([999999.0]))  # not contributing
            self.layers = nn.ModuleList(
                [nn.Sequential(nn.Linear(1, 4), nn.Linear(4, 1)) for _ in range(3)]
            )
    
        def forward(self, x):
            for m in self.layers:
                x = m(x) + x  # self.x is never used in the forward pass
            return x
    
    net = Net()
    x = torch.ones((1, 1))
    # compute the forward pass to create the computation graph
    y = net(x)
    
    # use computation graph to find all contributing tensors
    def get_contributing_params(y, top_level=True):
        """Walk the autograd graph of y and yield every leaf tensor (parameter) it depends on."""
        nf = y.grad_fn.next_functions if top_level else y.next_functions
        for f, _ in nf:
            try:
                # AccumulateGrad nodes hold the leaf tensor in .variable
                yield f.variable
            except AttributeError:
                pass  # node has no tensor (e.g. an intermediate op or a None edge)
            if f is not None:
                yield from get_contributing_params(f, top_level=False)
    
    contributing_parameters = set(get_contributing_params(y))
    all_parameters = set(net.parameters())
    non_contributing = all_parameters - contributing_parameters
    print(non_contributing)  # prints the unused Parameter containing tensor([999999.])
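    If the goal is to then disable gradients on those parameters (as suggested in the question), a minimal sketch would be to freeze them before wrapping the model in DistributedDataParallel and to hand only trainable parameters to the optimizer; the optimizer choice and learning rate here are placeholders:

    # Freeze the parameters that never reach the loss, so DDP won't wait
    # for gradients that will never be produced.
    for p in non_contributing:
        p.requires_grad_(False)
    
    # Only pass the trainable parameters to the optimizer.
    optimizer = torch.optim.SGD(
        (p for p in net.parameters() if p.requires_grad), lr=0.01
    )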