Tags: python, machine-learning, deep-learning, pytorch

How to efficiently calculate gradients of all outputs with respect to parameters?


I have a relatively simple requirement, but surprisingly it does not seem to be straightforward to implement in PyTorch. Given a neural network with $P$ parameters that outputs a vector of length $Y$, and a batch of $B$ data inputs, I would like to calculate the gradients of the outputs with respect to the model's parameters.

In other words, I would like the following function:

def calculate_gradients(model, X):
    """
    Args:
        model: nn.Module with P parameters in total that outputs a tensor of size (B, Y).
        X: torch tensor of shape (B, .).

    Returns:
        torch tensor of shape (B, Y, P).
    """
    # function logic here

Unfortunately, I don't currently see an obvious way of calculating this efficiently, especially without aggregating over the data or target dimensions. A minimal working example below involves looping over input and target dimensions, but surely there is a more efficient way?

import torch
from torchvision import datasets, transforms
import torch.nn as nn

###### SETUP ######

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        h = self.fc1(x)
        pred = self.fc2(self.relu(h))
        return pred
    
train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5,), (0.5,))
                               ]))

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)

X, y = next(iter(train_dataloader))  # take a batch of data

net = MLP(28*28, 20, 10)  # define a network


###### CALCULATE GRADIENTS ######
def calculate_gradients(model, X):
    # Create a tensor to hold the gradients
    gradients = torch.zeros(X.shape[0], 10, sum(p.numel() for p in model.parameters()))

    # Calculate the gradients for each input and target dimension
    for i in range(X.shape[0]):
        for j in range(10):
            model.zero_grad()
            output = model(X[i])
            # Calculate the gradients
            grads = torch.autograd.grad(output[j], model.parameters())
            # Flatten the gradients and store them
            gradients[i, j, :] = torch.cat([g.view(-1) for g in grads])
            
    return gradients

grads = calculate_gradients(net, X.view(X.shape[0], -1))

Edit: I ran some quick benchmarks of Felix Zimmermann's solution which does indeed provide some nice speedups for this toy problem on my machine.

import time

start = time.time()
for _ in range(1000):
    grads = calculate_gradients(net, X.view(X.shape[0], -1))
end = time.time()
print('Loop solution', end - start)

start = time.time()
for _ in range(1000):
    params = {k: v.detach() for k, v in net.named_parameters()}
    buffers = {k: v.detach() for k, v in net.named_buffers()}
    grads2 = torch.vmap(one_sample)(X.flatten(1))
end = time.time()
print('Vmap solution', end - start)

Which outputs

Loop solution 8.408899307250977
Vmap solution 2.355229139328003

Note that the performance gains are likely to be much greater in more realistic settings with larger batches on GPUs.
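
For reference, here is a rough sketch of how the vmap version could be timed on a GPU (not benchmarked here; it assumes a CUDA device and that one_sample, params and buffers are defined at module level as in the solution below). The torch.cuda.synchronize() calls matter because CUDA kernels launch asynchronously, so timing without them can be misleading.

net.to('cuda')
X_dev = X.view(X.shape[0], -1).to('cuda')
# re-extract the parameters/buffers so the closure sees the GPU tensors
params = {k: v.detach() for k, v in net.named_parameters()}
buffers = {k: v.detach() for k, v in net.named_buffers()}

torch.cuda.synchronize()  # finish pending kernels before starting the clock
start = time.time()
for _ in range(1000):
    grads2 = torch.vmap(one_sample)(X_dev)
torch.cuda.synchronize()  # wait for all launched kernels before stopping it
print('Vmap solution (GPU)', time.time() - start)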


Solution

  • To solve this problem, we need three ideas:

    1. Treat the network as a function of its parameters, i.e. call it statelessly with an explicit set of parameters (torch.func.functional_call).
    2. Compute the Jacobian of that function with respect to the parameters (torch.func.jacrev); its rows are exactly the gradients of each output with respect to the parameters.
    3. Vectorize the whole computation over the batch (torch.vmap).

    These are all part of functorch / torch.func. A standalone illustration of the first idea follows this list.
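
    As a quick sketch of the first idea (the tiny linear layer here is illustrative, not from the original post): functional_call runs a module with an explicit dictionary of parameters instead of the module's own state.

    lin = nn.Linear(3, 2)
    p = {k: v.detach() for k, v in lin.named_parameters()}
    x = torch.randn(1, 3)
    # stateless call: same result as lin(x), but the parameters are explicit inputs
    out = torch.func.functional_call(lin, p, (x,))
    print(torch.allclose(out, lin(x)))  # True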

    Putting it all together, this does the same as your code:

    # extract the parameters and buffers for a functional call
    params = {k: v.detach() for k, v in net.named_parameters()}
    buffers = {k: v.detach() for k, v in net.named_buffers()}
    
    def one_sample(sample):
        # this will calculate the gradients for a single sample
        # we want the gradients for each output wrt to the parameters
        # this is the same as the jacobian of the network wrt the parameters
    
        # define a function that takes the parameters as input and returns the output of the network
        call = lambda x: torch.func.functional_call(net, (x, buffers), sample)
        
        # calculate the jacobian of the network wrt the parameters
        J = torch.func.jacrev(call)(params)
        
        # J is a dictionary with keys the names of the parameters and values the gradients
        # we want a tensor
        grads = torch.cat([v.flatten(1) for v in J.values()], -1)
        return grads
    
    # now we can use vmap to calculate the gradients for all samples at once
    grads2 = torch.vmap(one_sample)(X.flatten(1))

    print(torch.allclose(grads, grads2))
    

    It should run in parallel; you should try it out for bigger models etc., as I did not benchmark it. As a side note, torch.func.jacrev uses reverse-mode differentiation, which fits this setting where each sample has far more parameters than outputs; if the outputs outnumbered the parameters, torch.func.jacfwd would be the better choice.

    This is also related to, for example, Pytorch: Gradient of output w.r.t parameters (which tbh doesn't have a great answer), and to pytorch.org/tutorials/intermediate/per_sample_grads.html, which shows some of the functions within torch.func for calculating per-sample gradients.
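
    For reference, the per-sample gradient pattern from that tutorial looks roughly like this (a sketch assuming a scalar cross-entropy loss; compute_loss is an illustrative name, not from the original post):

    import torch.nn.functional as F

    def compute_loss(params, buffers, sample, target):
        # functional call on a single sample; unsqueeze(0) adds the batch dimension
        out = torch.func.functional_call(net, (params, buffers), (sample.unsqueeze(0),))
        return F.cross_entropy(out, target.unsqueeze(0))

    # grad differentiates with respect to the first argument (params);
    # vmap maps over the sample and target dimensions only
    per_sample_grads = torch.vmap(
        torch.func.grad(compute_loss), in_dims=(None, None, 0, 0)
    )(params, buffers, X.flatten(1), y)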