Tags: pytorch, automatic-differentiation

How to use torch.unique to filter duplicate values, calculate an expensive function, map it back, and then calculate the gradient?


I'm trying to optimize a computation in PyTorch by first identifying the unique elements of a tensor, applying an expensive function (e.g., torch.exp) to these unique elements only, and then mapping the results back to the original tensor's shape to calculate the gradient with respect to the original tensor.

My motivation is to avoid redundant calculations of the expensive function for duplicate values in the input tensor, which can significantly improve performance when the input contains many duplicates.

Here's the code snippet that demonstrates my approach, but it results in an error: torch.unique is not differentiable, so its output does not require grad and the autograd.grad call below fails:

import torch

inputs = torch.rand(100)
inputs = torch.round(inputs, decimals=2)  # rounding creates duplicate values
inputs.requires_grad_(True)

unique_inputs, inverse_indices = torch.unique(inputs, return_inverse=True)
print(f"There are {unique_inputs.numel()} unique elements.")

unique_exp = torch.exp(unique_inputs)
full_exp = unique_exp[inverse_indices]

torch.autograd.grad(full_exp[0], inputs) # <-- Error here

Is there another way I can do that?


Solution

  • You can do this with a custom torch.autograd.Function that implements its own backward pass:

    import torch
    import torch.nn as nn
    
    class UniqueForward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inputs):
            unique_inputs, inverse_indices = torch.unique(inputs, return_inverse=True)
            # Count the multiplicity of each unique value so the backward
            # pass can split the gradient evenly among the duplicates.
            unique_counts = torch.bincount(inverse_indices)
            ctx.save_for_backward(inverse_indices, unique_counts)
            # Integer indices carry no gradient.
            ctx.mark_non_differentiable(inverse_indices)
            return unique_inputs, inverse_indices
    
        @staticmethod
        def backward(ctx, grad_unique, grad_inverse):
            inverse_indices, unique_counts = ctx.saved_tensors
            # Scatter each unique value's gradient back to all of its duplicates,
            # dividing by the multiplicity so the total gradient is preserved.
            grad_inputs = grad_unique[inverse_indices] / unique_counts[inverse_indices]
            return grad_inputs
    
    class EfficientFunction(nn.Module):
        def __init__(self, function):
            # `function` should be a callable element-wise function
            super().__init__()
            self.function = function
            self.unique_forward = UniqueForward.apply
    
        def forward(self, inputs):
            # Evaluate the expensive function once per unique value, then
            # broadcast the results back to the original positions.
            unique_inputs, inverse_indices = self.unique_forward(inputs)
            unique_results = self.function(unique_inputs)
            full_results = unique_results[inverse_indices]
            return full_results
    

    The EfficientFunction module works whenever 1) the input contains duplicated values and 2) the wrapped function is applied element-wise. In the backward pass, each unique value's gradient is divided by its multiplicity before being scattered back, so the total gradient over all duplicates is preserved; the per-element gradients match the naive computation exactly when all duplicates receive the same upstream gradient (as with a sum or mean loss).
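
    As a quick sanity check of the gradient splitting, here is a minimal sketch (my own illustration, not part of the original answer) on a tensor with an explicit duplicate:

    # x holds the value 1.0 twice; unique() collapses them to one element
    x = torch.tensor([1.0, 1.0, 2.0], requires_grad=True)
    u, inv = UniqueForward.apply(x)
    y = u[inv].sum()  # equivalent to x.sum() through the unique/inverse round trip
    y.backward()
    print(x.grad)  # tensor([1., 1., 1.]): grad 2 on u[0] split across its 2 duplicates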

    Example:

    def expensive_function(x):
        # use exp as example of expensive function
        return torch.exp(x)
    
    efficient_function = EfficientFunction(expensive_function)
    
    # test efficient version
    inputs = torch.rand(100)
    inputs = torch.round(inputs, decimals=2)
    inputs.requires_grad_(True)
    
    outputs = efficient_function(inputs)
    loss = outputs.mean()
    loss.backward()
    
    # test naive version
    inputs2 = inputs.clone().detach().requires_grad_(True)
    naive_results = expensive_function(inputs2)
    naive_loss = naive_results.mean()
    naive_loss.backward()
    
    # compare gradients: should print True
    print(torch.allclose(inputs.grad, inputs2.grad))
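
    The savings only materialize when the wrapped function is genuinely expensive relative to the overhead of torch.unique. A rough timing sketch (my own addition, not from the original answer; the numbers depend on hardware, the duplicate ratio, and the cost of the function):

    import time
    
    # ~101 distinct values among one million elements
    big_inputs = torch.round(torch.rand(1_000_000), decimals=2)
    
    t0 = time.perf_counter()
    for _ in range(10):
        _ = expensive_function(big_inputs)
    t1 = time.perf_counter()
    for _ in range(10):
        _ = efficient_function(big_inputs)
    t2 = time.perf_counter()
    
    print(f"naive:     {t1 - t0:.4f}s")
    print(f"efficient: {t2 - t1:.4f}s")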