Tags: python, autograd, automatic-differentiation

How to assign equations element by element in autograd


I am trying to implement an autograd-based solver for a nonlinear PDE. As with most PDEs, I need to be able to operate on individual entries of my input vector, but apparently this breaks autograd. I have created this simple example to show the issue I'm facing:

The following code works:

import autograd
import autograd.numpy as np

def my_equation(x):
    eq = x
    return eq

x = np.random.randn(2,)
jac = autograd.jacobian(my_equation)
jacval = jac(x)
print(jacval)

The following code does not work:

def my_equation(x):
    eq = x
    # This breaks the code, even though
    # it is a trivial line
    eq[1] = x[1]
    return eq

x = np.random.randn(2,)
jac = autograd.jacobian(my_equation)
jacval = jac(x)
print(jacval)

I've read in a couple of places that you can't assign to individual array elements in autograd. Is this actually true? Is there any workaround? Or maybe another library to suggest?

Thank you!


Solution

  • Indeed, indexed assignment into arrays is not possible in autograd. People have written PDE solvers in autograd (see https://github.com/HIPS/autograd/tree/master/examples/fluidsim), so perhaps there is a way to solve your problem while staying in autograd; one option is sketched below.
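
    One common workaround within autograd itself is to avoid indexed assignment and assemble the output from slices instead, e.g. with np.concatenate. Here is a minimal sketch in that style (the 2 * x[0:1] component is made up just so the two entries differ):

    import autograd
    import autograd.numpy as np

    def my_equation(x):
        # Reading slices of x is fine in autograd; only writing into
        # an existing array breaks the tracing.
        eq0 = 2 * x[0:1]
        eq1 = x[1:2]
        # Build the result as a new array instead of assigning
        # into eq element by element.
        return np.concatenate([eq0, eq1])

    x = np.random.randn(2,)
    jacval = autograd.jacobian(my_equation)(x)
    print(jacval)  # Jacobian is [[2, 0], [0, 1]]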

    JAX appears to offer a workaround with the jax.ops package (see https://jax.readthedocs.io/en/latest/jax.ops.html and https://github.com/google/jax#current-gotchas).
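
    For reference, a minimal sketch of that functional-update style (note that newer JAX versions expose it as x.at[idx].set(val) rather than the jax.ops functions):

    import jax
    import jax.numpy as jnp

    def my_equation(x):
        eq = 2 * x
        # Functional equivalent of eq[1] = x[1]: returns a new array
        # with entry 1 replaced instead of writing in place.
        eq = eq.at[1].set(x[1])
        return eq

    x = jnp.array([0.5, -1.3])
    jacval = jax.jacobian(my_equation)(x)
    print(jacval)  # Jacobian is [[2, 0], [0, 1]]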

    It appears that indexed assignment is possible in PyTorch, which suggests PyTorch would be your way to go. The following code works.

    import torch
    
    def f(x):
        eq = 2*x
        # In-place indexed assignment on a tracked tensor; PyTorch's
        # autograd records it and only raises an error if a value needed
        # for backward was overwritten, which is not the case here.
        eq[0] = x[0]
        return eq
    
    x = torch.rand(4, requires_grad=True)
    y = f(x)
    z = torch.sum(y)
    z.backward()
    print(x.grad)  # prints tensor([1., 2., 2., 2.])
    
    

    "Or maybe another library to suggest?"

    Take a look at dolfin-adjoint. If you can write your PDE solver in FEniCS, then it could be helpful.

    http://www.dolfin-adjoint.org/en/latest/

    Note that naive backpropagation through a nonlinear equation solver can be an inefficient way to compute the derivatives of a scalar loss function. See https://cs.stanford.edu/~ambrad/adjoint_tutorial.pdf for a solid tutorial on the adjoint method, including for nonlinear problems; a toy scalar illustration is sketched below.
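
    To make the adjoint idea concrete, here is a tiny hand-made sketch (the residual g(u, p) = u**3 + u - p and the loss L(u) = u**2 are made up purely for illustration): instead of backpropagating through every solver iteration, you differentiate the residual once at the converged solution.

    # Solve g(u, p) = u**3 + u - p = 0 for u with Newton's method,
    # then compute dL/dp for L(u) = u**2 without differentiating
    # through the solver's iterations.
    def solve(p, iters=50):
        u = 0.0
        for _ in range(iters):
            u -= (u**3 + u - p) / (3 * u**2 + 1)  # Newton step
        return u

    p = 2.0
    u = solve(p)

    # Adjoint / implicit-function relation for a scalar unknown:
    #   dL/dp = -(dL/du) * (dg/dp) / (dg/du)
    dL_du = 2 * u
    dg_du = 3 * u**2 + 1
    dg_dp = -1.0
    print(-dL_du * dg_dp / dg_du)  # 0.5, since u converges to 1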