Tags: python, pytorch, fft, autograd

PyTorch autograd with complex Fourier transforms gives wrong results


I am trying to implement a real-valued cost function that evaluates a complex input in frequency space with PyTorch and autograd, because I am interested in the gradient of the cost function with respect to the input. When I compare the autograd result with the derivative I computed by hand (using Wirtinger calculus), I get a different result. I'm not sure whether the mistake is in my implementation or in my own derivation of the gradient.

The cost function and its derivative by hand look like this: [image: formula of the cost function]

Here is my implementation:

import torch

def f_derivative_by_hand(f):
    # Hand-derived (Wirtinger) gradient of the cost C
    f = torch.tensor(f, dtype=torch.complex128)
    ftilde = torch.fft.fft(f)
    absf = torch.abs(ftilde)
    f2 = absf**2
    C = torch.trapz(f2).numpy()  # trapezoidal rule with unit spacing
    grads = 2 * torch.fft.ifft(ftilde).numpy()
    return C, grads

def f_derivative_autograd(f):
    # Same cost, differentiated with autograd
    f = torch.tensor(f, dtype=torch.complex128, requires_grad=True)
    ftilde = torch.fft.fft(f)
    f2 = torch.abs(ftilde)**2
    C = torch.trapz(f2)
    C.backward()  # populates f.grad
    grads = f.grad
    return C.detach().numpy(), grads.detach().numpy()

When I evaluate some data with both functions, the gradient from automatic differentiation is tilted compared with the hand-derived one (note that I normalized the plotted arrays): [image: autograd and derivative by hand comparison]

I suspect there could also be something wrong with the automatic differentiation of the FFT, though: if I remove the Fourier transform from the cost function and integrate the function in real space, both implementations match exactly except at the edges (again normalized): [image: no-FFT autograd and derivative by hand]

It would be fantastic if someone could help me figure out what is wrong!


Solution

  • After some more investigation, I found the cause of the tilted derivatives. The trapezoidal integration rule treats the two endpoint samples differently from the interior ones (they get half weight), which produces artifacts at the boundaries, as discussed in this PyTorch forum post.

    In my original problem, the observed tilt results from integrating the Fourier-transformed signal, which is asymmetric. The boundary artifacts in frequency space introduce spatial frequencies that tilt the derivative in real space.

    For me, the simplest solution is to replace the trapezoidal rule with a plain sum weighted by the frequency differential. With that change, the autograd and hand-derived gradients agree.
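
A rough sketch of this fix, for illustration only: it compares the `torch.trapz` cost from the question with a plain sum weighted by a frequency spacing. The helper names (`cost_trapz`, `cost_sum`, `autograd_gradient`), the `dk` parameter with its default of 1.0 (chosen to mirror the unit spacing that `torch.trapz` uses by default), and the random test signal are assumptions, not part of the original code.

```python
import torch


def cost_trapz(f):
    """Cost as in the question: trapezoidal rule over |FFT(f)|^2, unit spacing."""
    ftilde = torch.fft.fft(f)
    return torch.trapz(torch.abs(ftilde) ** 2)


def cost_sum(f, dk=1.0):
    """Cost as a plain sum weighted by the frequency spacing dk.

    dk is a hypothetical parameter; dk=1.0 mirrors torch.trapz's default spacing.
    """
    ftilde = torch.fft.fft(f)
    return torch.sum(torch.abs(ftilde) ** 2) * dk


def autograd_gradient(cost, f_values):
    """Return the autograd gradient of a real-valued cost w.r.t. a complex input."""
    f = f_values.clone().requires_grad_(True)
    cost(f).backward()
    return f.grad.detach()


# Assumed test signal: random complex samples (not the data from the question).
torch.manual_seed(0)
n = 64
f0 = torch.complex(
    torch.randn(n, dtype=torch.float64),
    torch.randn(n, dtype=torch.float64),
)

grad_sum = autograd_gradient(cost_sum, f0)
grad_trapz = autograd_gradient(cost_trapz, f0)

# Under PyTorch's complex gradient convention, the sum-based cost with dk=1
# has gradient 2 * n * f (the factor n comes from the unnormalized FFT).
expected = 2 * n * f0

# Difference between the trapz gradient and the expected one, viewed in
# frequency space: only the two endpoint bins should be non-zero.
residual = grad_trapz - expected
residual_spectrum = torch.fft.fft(residual) / n

print("sum   vs expected, max abs error:", torch.max(torch.abs(grad_sum - expected)).item())
print("trapz vs expected, max abs error:", torch.max(torch.abs(residual)).item())
print("residual energy in interior bins:", torch.max(torch.abs(residual_spectrum[1:-1])).item())
```

With the plain sum, the autograd gradient is proportional to `f`, so it agrees with the hand-derived `2 * ifft(ftilde)` up to a constant factor that disappears after normalization. With `torch.trapz`, the two endpoint frequency bins receive half weight, so the discrepancy is confined to those two bins; in real space it appears as a constant offset plus a complex oscillation, which is consistent with the tilt described above.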