I'm trying to optimize a computation in PyTorch by first identifying the unique elements of a tensor, applying an expensive function (e.g., torch.exp) to these unique elements only, and then mapping the results back to the original tensor's shape to calculate the gradient with respect to the original tensor.
My motivation is to avoid redundant calculations of the expensive function for duplicate values in the input tensor, which would significantly improve performance.
Here's a code snippet that demonstrates my approach, but it fails with an error because torch.unique is not differentiable:
import torch
inputs = torch.rand(100)
inputs = torch.round(inputs, decimals=2)  # rounding to 2 decimals forces duplicate values
inputs.requires_grad_(True)
unique_inputs, inverse_indices = torch.unique(inputs, return_inverse=True)
print(f"There are {unique_inputs.numel()} unique elements.")
unique_exp = torch.exp(unique_inputs)
full_exp = unique_exp[inverse_indices]
torch.autograd.grad(full_exp[0], inputs) # <-- Error here
Is there another way I can do that?
You can do this with a custom backward function:
import torch
import torch.nn as nn

class UniqueForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs):
        unique_inputs, inverse_indices = torch.unique(inputs, return_inverse=True)
        # number of occurrences of each unique value (flatten so bincount
        # also works for multi-dimensional inputs)
        unique_counts = torch.bincount(inverse_indices.flatten())
        ctx.save_for_backward(inverse_indices, unique_counts)
        return unique_inputs, inverse_indices

    @staticmethod
    def backward(ctx, grad_unique, grad_inverse):
        # grad_inverse is None: inverse_indices is an integer tensor,
        # so autograd treats it as non-differentiable
        inverse_indices, unique_counts = ctx.saved_tensors
        # scatter each unique element's gradient back to every position
        # holding that value, divided by the count so that the gradient
        # accumulated through the later indexing op is not double-counted
        grad_inputs = grad_unique[inverse_indices]
        grad_inputs = grad_inputs / unique_counts[inverse_indices]
        return grad_inputs

class EfficientFunction(nn.Module):
    def __init__(self, function):
        # function should be a callable element-wise function
        super().__init__()
        self.function = function
        self.unique_forward = UniqueForward.apply

    def forward(self, inputs):
        unique_inputs, inverse_indices = self.unique_forward(inputs)
        # apply the expensive function once per unique value only
        unique_results = self.function(unique_inputs)
        # map the results back to the original tensor's shape
        full_results = unique_results[inverse_indices]
        return full_results
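To see why backward divides by unique_counts, trace a tensor with explicit duplicates through the pipeline (a minimal sketch, reusing the UniqueForward defined above):

x = torch.tensor([1.0, 1.0, 2.0], requires_grad=True)
u, inv = UniqueForward.apply(x)     # u = [1., 2.], inv = [0, 0, 1]
torch.exp(u)[inv].sum().backward()
# the indexing op accumulates 2 * exp(1) into the first unique element;
# dividing by the count of 2 recovers the naive element-wise gradient
print(torch.allclose(x.grad, torch.exp(x).detach()))  # True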
The EfficientFunction module should work for any case where (1) the inputs have duplicated values and (2) the function in question applies an element-wise operation.
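Requirement 2 matters: a function like torch.softmax couples all elements through its normalization, so applying it to the unique values and indexing back gives a different result. A quick illustration:

x = torch.tensor([1.0, 1.0, 2.0])
u, inv = torch.unique(x, return_inverse=True)
print(torch.softmax(x, dim=0))       # ~[0.2119, 0.2119, 0.5761]
print(torch.softmax(u, dim=0)[inv])  # ~[0.2689, 0.2689, 0.7311] (wrong)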
Example:
def expensive_function(x):
    # use exp as an example of an expensive element-wise function
    return torch.exp(x)
efficient_function = EfficientFunction(expensive_function)
# test efficient version
inputs = torch.rand(100)
inputs = torch.round(inputs, decimals=2)
inputs.requires_grad_(True)
outputs = efficient_function(inputs)
loss = outputs.mean()
loss.backward()
# test naive version
inputs2 = inputs.clone().detach().requires_grad_(True)
naive_results = expensive_function(inputs2)
naive_loss = naive_results.mean()
naive_loss.backward()
# compare gradients: should print True
print(torch.allclose(inputs.grad, inputs2.grad))
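Whether the unique-based version actually wins depends on how expensive the function is relative to the overhead of torch.unique (a sort) plus the extra indexing; plain torch.exp on a small tensor is usually too cheap to benefit. A rough wall-clock sketch, where slow_elementwise is a hypothetical stand-in for a genuinely expensive element-wise function (on GPU you would also need torch.cuda.synchronize before reading the clock):

import time

def slow_elementwise(x):
    # hypothetical stand-in: 50 chained element-wise ops
    for _ in range(50):
        x = torch.exp(-x * x)
    return x

big_inputs = torch.round(torch.rand(1_000_000), decimals=2)  # ~101 unique values
efficient = EfficientFunction(slow_elementwise)

t0 = time.perf_counter()
naive_out = slow_elementwise(big_inputs)
t1 = time.perf_counter()
efficient_out = efficient(big_inputs)
t2 = time.perf_counter()

print(f"naive:     {t1 - t0:.4f}s")
print(f"efficient: {t2 - t1:.4f}s")
print(torch.allclose(naive_out, efficient_out))  # True: identical values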