Julia: Zygote.@adjoint from Enzyme.autodiff


Given the function f! below:

function f!(s::Vector, a::Vector, b::Vector)
  
  s .= a .+ b
  return nothing

end # f!

How can I define an adjoint for Zygote based on

Enzyme.autodiff(f!, Const, Duplicated(s, dz_ds), Duplicated(a, zero(a)), Duplicated(b, zero(b)))?

Zygote.@adjoint f!(s, a, b) = f!(s, a, b), # What should go here?

Solution

  • I figured out a way and am sharing it here.

    For a given function foo, Zygote.pullback(foo, args...) returns foo(args...) together with the backward pass (a closure from which gradients can be computed).

    My goal is to tell Zygote to use Enzyme for the backward pass.

    This can be done with Zygote.@adjoint (see the Zygote documentation on custom adjoints).

    For array-valued functions, Enzyme requires a mutating version that returns nothing and writes its result into one of its arguments (see the Enzyme documentation).

    The function f! in the question post is an Enzyme-compatible version of a sum of two arrays.

    Since f! returns nothing, Zygote has no output through which to pass a gradient: the backward pass of an adjoint defined directly on f! would simply receive nothing.

    A solution is to place f! inside a wrapper (say f) that returns the array s

    and to define Zygote.@adjoint for f, rather than f!.

    Hence,

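    using Enzyme, Zygote

    # Non-mutating wrapper around f!: allocates the output and returns it,
    # so Zygote has an array value to differentiate through.
    function f(a::Vector, b::Vector)
    
      s = zero(a)
      f!(s, a, b)
      return s
    
    end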
    
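    # Backward pass computed by Enzyme: given dz/ds, return (dz/da, dz/db).
    function enzyme_back(dzds, a, b)
    
      s    = zero(a)      # scratch output buffer for f!
      ds   = copy(dzds)   # Enzyme may overwrite the shadow of s, so pass a copy
      dzda = zero(a)      # gradient accumulators, one per input
      dzdb = zero(b)
      # Note: recent Enzyme versions expect the mode as the first argument,
      # i.e. Enzyme.autodiff(Reverse, f!, Const, ...).
      Enzyme.autodiff(
        f!,
        Const,
        Duplicated(s, ds),
        Duplicated(a, dzda),
        Duplicated(b, dzdb)
      )
      return (dzda, dzdb)
    
    end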
    

    and

    Zygote.@adjoint f(a, b) = f(a, b), dzds -> enzyme_back(dzds, a, b)
    

    tell Zygote to use Enzyme in the backward pass.
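
    As a sanity check on what the backward pass should return for this particular f!, note that s = a + b has identity Jacobians with respect to both a and b, so the incoming gradient dz/ds should come back unchanged for each input. The sketch below verifies this with central finite differences using only base Julia. The names add!, manual_back and z are hypothetical helpers introduced for illustration; they are not part of Enzyme or Zygote.

    ```julia
    # Forward map s = a + b, written in the same mutating style as f!.
    function add!(s::Vector, a::Vector, b::Vector)
        s .= a .+ b
        return nothing
    end

    # Hand-written pullback (hypothetical helper). Since ds/da and ds/db are
    # both the identity, dz/ds is passed through unchanged to both inputs.
    manual_back(dzds) = (copy(dzds), copy(dzds))

    # Scalar test function z = sum(w .* s), chosen so that dz/ds = w.
    function z(a, b, w)
        s = zero(a)
        add!(s, a, b)
        return sum(w .* s)
    end

    a = [1.0, 2.0, 3.0]
    b = [4.0, 5.0, 6.0]
    w = [0.5, -1.0, 2.0]

    dzda, dzdb = manual_back(w)

    # Central finite differences along each coordinate of a.
    h = 1e-6
    fd_a = map(eachindex(a)) do i
        e = zeros(length(a))
        e[i] = h
        (z(a .+ e, b, w) - z(a .- e, b, w)) / (2h)
    end

    println(isapprox(fd_a, dzda; atol = 1e-6))
    ```

    The pair returned by enzyme_back(w, a, b) should agree with (dzda, dzdb) here, which gives a quick way to confirm that the Enzyme call is wired up as intended.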


    Finally, you can check that calling Zygote.gradient either on

    g1(a::Vector, b::Vector) = sum(abs2, a + b)
    

    or

    g2(a::Vector, b::Vector) = sum(abs2, f(a, b))
    

    yields the same results.
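
    For a reference value to compare both calls against, the gradient of sum(abs2, a + b) with respect to either argument is 2(a + b). The sketch below, in base Julia only, checks that closed form against central finite differences. The names expected_grad and fd_grad are hypothetical helpers introduced for illustration.

    ```julia
    g1(a::Vector, b::Vector) = sum(abs2, a + b)

    # Closed-form gradient (hypothetical helper): d/da and d/db of
    # sum((a + b).^2) are both 2(a + b).
    expected_grad(a, b) = (2 .* (a .+ b), 2 .* (a .+ b))

    # Central finite-difference gradient of fn with respect to x
    # (hypothetical helper, used only to cross-check the closed form).
    function fd_grad(fn, x; h = 1e-4)
        map(eachindex(x)) do i
            e = zeros(length(x))
            e[i] = h
            (fn(x .+ e) - fn(x .- e)) / (2h)
        end
    end

    a = [1.0, 2.0, 3.0]
    b = [4.0, 5.0, 6.0]

    ga, gb = expected_grad(a, b)          # both equal [10.0, 14.0, 18.0]
    fa = fd_grad(x -> g1(x, b), a)        # numerical gradient w.r.t. a
    fb = fd_grad(x -> g1(a, x), b)        # numerical gradient w.r.t. b

    println(isapprox(fa, ga; atol = 1e-5) && isapprox(fb, gb; atol = 1e-5))
    ```

    With these inputs, Zygote.gradient(g1, a, b) and Zygote.gradient(g2, a, b) should each return a tuple matching (ga, gb) up to floating-point error.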