I have a relatively simple requirement, but surprisingly this does not seem to be straightforward to implement in PyTorch. Given a neural network with $P$ parameters that outputs a vector of length $Y$, and a batch of $B$ data inputs, I would like to calculate the gradients of the outputs with respect to the model's parameters.
In other words, I would like the following function:
```python
def calculate_gradients(model, X):
    """
    Args:
        model: nn.Module with P parameters in total that outputs a tensor of shape (B, Y).
        X: torch tensor of shape (B, ...).
    Returns:
        torch tensor of shape (B, Y, P)
    """
    # function logic here
```
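Written out elementwise, the requested tensor is a per-sample Jacobian. The symbols $f$, $\theta$ and $x_b$ are notation introduced here for illustration: $f(\,\cdot\,;\theta)$ is the network, $\theta \in \mathbb{R}^P$ is all of its parameters flattened into one vector, and $x_b$ is the $b$-th input of the batch.

$$
G_{b,y,p} = \frac{\partial f_y(x_b;\theta)}{\partial \theta_p}, \qquad b = 1,\dots,B,\quad y = 1,\dots,Y,\quad p = 1,\dots,P
$$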
Unfortunately, I don't currently see an obvious way of calculating this efficiently, especially without aggregating over the data or target dimensions. The minimal working example below loops over the input and target dimensions, but surely there is a more efficient way?
```python
import torch
from torchvision import datasets, transforms
import torch.nn as nn

###### SETUP ######
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h = self.fc1(x)
        pred = self.fc2(self.relu(h))
        return pred

train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms.Compose(
                                   [transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))
                                    ]))
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)
X, y = next(iter(train_dataloader))  # take the first batch of data (shuffle=False)
net = MLP(28*28, 20, 10)  # define a network

###### CALCULATE GRADIENTS ######
def calculate_gradients(model, X):
    # Create a tensor to hold the gradients
    gradients = torch.zeros(X.shape[0], 10, sum(p.numel() for p in model.parameters()))
    # Calculate the gradients for each input and target dimension
    for i in range(X.shape[0]):
        for j in range(10):
            # not strictly needed: autograd.grad does not accumulate into .grad
            model.zero_grad()
            output = model(X[i])
            # Calculate the gradients of output j wrt all parameters
            grads = torch.autograd.grad(output[j], model.parameters())
            # Flatten the gradients and store them
            gradients[i, j, :] = torch.cat([g.view(-1) for g in grads])
    return gradients

grads = calculate_gradients(net, X.view(X.shape[0], -1))
```
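One cheap improvement to the loop itself: the forward pass `model(X[i])` does not depend on `j`, so it can be computed once per sample and reused for all $Y$ backward passes by passing `retain_graph=True`. This cuts the forward passes from $B \cdot Y$ to $B$, but still needs $B \cdot Y$ backward passes, so it is no substitute for vectorisation. A minimal sketch, assuming every parameter is used in the forward pass (otherwise `autograd.grad` raises unless `allow_unused=True`); the name `calculate_gradients_reuse` is hypothetical:

```python
import torch
import torch.nn as nn


def calculate_gradients_reuse(model: nn.Module, X: torch.Tensor) -> torch.Tensor:
    """Per-sample Jacobian of shape (B, Y, P) with one forward pass per sample.

    Hypothetical helper: same result as the loop above, but the forward pass
    is shared by all Y outputs of a sample.
    """
    params = list(model.parameters())
    rows = []
    for x in X:  # iterate over the batch dimension
        output = model(x)  # shape (Y,)
        n_out = output.shape[0]
        per_output = []
        for j in range(n_out):
            # keep the graph alive until the last output has been differentiated
            g = torch.autograd.grad(output[j], params, retain_graph=j < n_out - 1)
            per_output.append(torch.cat([t.reshape(-1) for t in g]))
        rows.append(torch.stack(per_output))  # (Y, P)
    return torch.stack(rows)  # (B, Y, P)
```
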
Edit: I ran some quick benchmarks of Felix Zimmermann's solution, which does indeed give some nice speedups for this toy problem on my machine:
```python
import time

# one_sample is defined in the answer below
start = time.time()
for _ in range(1000):
    grads = calculate_gradients(net, X.view(X.shape[0], -1))
end = time.time()
print('Loop solution', end - start)

start = time.time()
for _ in range(1000):
    params = {k: v.detach() for k, v in net.named_parameters()}
    buffers = {k: v.detach() for k, v in net.named_buffers()}
    grads2 = torch.vmap(one_sample)(X.flatten(1))
end = time.time()
print('Vmap solution', end - start)
```
This outputs:

```
Loop solution 8.408899307250977
Vmap solution 2.355229139328003
```
Note that the performance gains are likely to be much greater in more realistic settings with larger batches on GPUs.
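For timings that are less sensitive to noise, PyTorch ships `torch.utils.benchmark.Timer`, which performs warm-up runs and synchronises CUDA when needed (useful when moving to GPUs). Note that it runs single-threaded by default, so the sketch passes `num_threads` explicitly. Assumptions: a randomly initialised 784-20-10 model and random inputs stand in for the MNIST setup above, and `loop_jacobian` / `vmap_jacobian` are hypothetical re-implementations of the two approaches, not the exact code that produced the numbers above.

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)
# Stand-in for the MLP above: 784 -> 20 -> 10 (random weights, random inputs)
net = nn.Sequential(nn.Linear(784, 20), nn.ReLU(), nn.Linear(20, 10))
X = torch.randn(2, 784)


def loop_jacobian(model, X):
    # one forward per sample, one backward per output
    params = list(model.parameters())
    out = []
    for x in X:
        y = model(x)
        out.append(torch.stack([
            torch.cat([g.reshape(-1) for g in torch.autograd.grad(y[j], params, retain_graph=True)])
            for j in range(y.shape[0])
        ]))
    return torch.stack(out)


def vmap_jacobian(model, X):
    # jacrev over the parameter dict, vmapped over the batch
    params = {k: v.detach() for k, v in model.named_parameters()}

    def one_sample(x):
        J = jacrev(lambda p: functional_call(model, p, (x,)))(params)
        return torch.cat([v.flatten(1) for v in J.values()], -1)

    return vmap(one_sample)(X)


for name, stmt in [("loop", "loop_jacobian(net, X)"), ("vmap", "vmap_jacobian(net, X)")]:
    timer = benchmark.Timer(
        stmt=stmt,
        globals={"loop_jacobian": loop_jacobian, "vmap_jacobian": vmap_jacobian, "net": net, "X": X},
        num_threads=torch.get_num_threads(),
        label=name,
    )
    print(timer.blocked_autorange())
```
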
To solve this problem, we need three ideas:

1. The gradients of the outputs with respect to the parameters form the Jacobian of the network with respect to the parameters: https://pytorch.org/functorch/stable/generated/functorch.jacrev.html
2. We can call a PyTorch model functionally, that is, as a function of its parameters, using `torch.func.functional_call`: https://pytorch.org/docs/stable/generated/torch.func.functional_call.html
3. PyTorch can vectorize many operations over a batch dimension using vmap: https://pytorch.org/functorch/stable/generated/functorch.vmap.html
All of this is part of functorch, which is now available in PyTorch itself as `torch.func`.
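The three building blocks can each be seen in isolation on a toy `nn.Linear(3, 2)` layer (chosen here for illustration, not taken from the question), where the Jacobian is known in closed form: for $y = Wx + b$, $\partial y_i / \partial b_k = \delta_{ik}$ and $\partial y_i / \partial W_{kl} = \delta_{ik} x_l$.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)
lin = nn.Linear(3, 2)  # toy model: y = W x + b, W of shape (2, 3)
params = {k: v.detach() for k, v in lin.named_parameters()}
x = torch.randn(3)

# 1. functional_call: run the module with an explicit parameter dict
y = functional_call(lin, params, (x,))  # same values as lin(x)

# 2. jacrev: Jacobian of the output wrt every parameter tensor (a dict like params)
J = jacrev(lambda p: functional_call(lin, p, (x,)))(params)
# J["weight"] has shape (2, 2, 3): output index i, then the weight's own shape (k, l)
# J["bias"] has shape (2, 2) and is the identity, since dy_i/db_k = delta_ik

# 3. vmap: apply a single-sample function over a leading batch dimension
X = torch.randn(4, 3)
Y = vmap(lambda xb: functional_call(lin, params, (xb,)))(X)  # shape (4, 2)
```
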
Putting it all together, this does the same as your code:
```python
# extract the parameters and buffers for a functional call
params = {k: v.detach() for k, v in net.named_parameters()}
buffers = {k: v.detach() for k, v in net.named_buffers()}

def one_sample(sample):
    # this calculates the gradients for a single sample:
    # we want the gradient of each output wrt the parameters,
    # which is the Jacobian of the network wrt the parameters
    # define a function that takes the parameters as input and returns the output of the network
    call = lambda p: torch.func.functional_call(net, (p, buffers), sample)
    # calculate the Jacobian of the network wrt the parameters
    J = torch.func.jacrev(call)(params)
    # J is a dictionary with the parameter names as keys and the gradients as values;
    # each value has shape (Y, *param.shape), so flatten and concatenate into a (Y, P) tensor
    grads = torch.cat([v.flatten(1) for v in J.values()], -1)
    return grads

# now we can use vmap to calculate the gradients for all samples at once
grads2 = torch.vmap(one_sample)(X.flatten(1))
print(torch.allclose(grads, grads2))
```
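For reuse, the same approach can be wrapped in the `(model, X)` signature requested in the question. This is a sketch; the name `calculate_gradients_vmap` is hypothetical (chosen so it does not shadow the loop version), and it assumes the model maps one unbatched input to a 1-D output of length $Y$, as the MLP does.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap


def calculate_gradients_vmap(model: nn.Module, X: torch.Tensor) -> torch.Tensor:
    """Return d model(x_b)[y] / d theta[p] as a tensor of shape (B, Y, P).

    Hypothetical wrapper around the jacrev + functional_call + vmap approach.
    """
    params = {k: v.detach() for k, v in model.named_parameters()}
    buffers = {k: v.detach() for k, v in model.named_buffers()}

    def one_sample(sample):
        # Jacobian of the network output wrt every parameter tensor
        J = jacrev(lambda p: functional_call(model, (p, buffers), (sample,)))(params)
        # each J[name] has shape (Y, *param.shape); flatten and concatenate to (Y, P)
        return torch.cat([v.flatten(1) for v in J.values()], dim=-1)

    # vectorise one_sample over the leading batch dimension of X
    return vmap(one_sample)(X)
```

With the question's `net` and `X`, `calculate_gradients_vmap(net, X.flatten(1))` should return a tensor of shape (2, 10, 15910), since the MLP has 784·20 + 20 + 20·10 + 10 = 15910 parameters.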
The vmap version should run in parallel; you should try it out on bigger models as well. I did not benchmark it myself.
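On bigger models, the choice of `jacrev` matters: reverse mode builds the Jacobian one output row at a time (about $Y$ vector-Jacobian products), while forward mode (`torch.func.jacfwd`) builds it one parameter column at a time (about $P$ Jacobian-vector products). Here $Y = 10 \ll P$, so reverse mode is the natural fit. A small sketch, on a toy model assumed for illustration, checking that both modes give the same Jacobian:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacfwd, jacrev

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))  # toy model, Y = 3
params = {k: v.detach() for k, v in net.named_parameters()}
x = torch.randn(8)

f = lambda p: functional_call(net, p, (x,))
J_rev = jacrev(f)(params)  # built from Y = 3 vector-Jacobian products
J_fwd = jacfwd(f)(params)  # built from P = 195 Jacobian-vector products

for name in params:
    # each entry has shape (Y, *param.shape) in both modes
    print(name, tuple(J_rev[name].shape), torch.allclose(J_rev[name], J_fwd[name], atol=1e-6))
```
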
This is also related to, for example, "Pytorch: Gradient of output w.r.t parameters" (which, to be honest, doesn't have a great answer), and to pytorch.org/tutorials/intermediate/per_sample_grads.html, which shows some of the functions within torch.func for calculating per-sample gradients.
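The per-sample-gradients pattern from that tutorial differs from the question in one respect: it differentiates a scalar loss per sample, so it uses `grad` instead of `jacrev` and returns one gradient per sample rather than one per output. A minimal sketch of that pattern, assuming a toy 3-class classifier with cross-entropy loss and random data (none of which come from the question):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))  # toy classifier
params = {k: v.detach() for k, v in net.named_parameters()}
X = torch.randn(5, 8)
targets = torch.randint(0, 3, (5,))


def loss_fn(p, x, t):
    # treat the single sample as a batch of one before calling the model
    logits = functional_call(net, p, (x.unsqueeze(0),))
    return F.cross_entropy(logits, t.unsqueeze(0))


# grad wrt the parameters (argument 0); vmap over samples and targets but not params
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, X, targets)
# per_sample_grads["0.weight"] has shape (5, 16, 8): one gradient per sample
```
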