
Confused about evaluating vector-Jacobian-product with non-identity vectors (JAX)


I'm confused about the meaning of evaluating vector-Jacobian products when the vector used for the VJP is a non-identity row vector (i.e. not a row of the identity matrix). My question pertains to vector-valued functions, not scalar functions like a loss. I will show a concrete example using Python and JAX, but this is a very general question about reverse-mode automatic differentiation.

Consider this simple vector-valued function for which the Jacobian is trivial to write down analytically:

import jax
jax.config.update("jax_enable_x64", True)  # `from jax.config import config` no longer works in recent JAX
import jax.numpy as jnp
from jax import vjp

# Define a vector-valued function (3 inputs --> 2 outputs) 
def vector_func(args):
    x,y,z = args
    a = 2*x**2 + 3*y**2 + 4*z**2
    b = 4*x*y*z
    return jnp.array([a, b])

# Define the inputs
x = 2.0
y = 3.0
z = 4.0

# Compute the vector-Jacobian product at the fiducial input point (x,y,z)
val, func_vjp = vjp(vector_func, (x, y, z))

print(val) 
# [99,96]

# now evaluate the function returned by vjp along with basis row vectors to pull out gradient of 1st and 2nd output components 
v1 = jnp.array([1.0, 0.0])  # pulls out the gradient of the 1st component wrt the 3 inputs, i.e., first row of Jacobian
v2 = jnp.array([0.0, 1.0])  # pulls out the gradient of the 2nd component wrt the 3 inputs, i.e., second row of Jacobian

gradient1 = func_vjp(v1)
print(gradient1)
# [8, 18, 32]

gradient2 = func_vjp(v2)
print(gradient2)
# [48,32,24]
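For reference, writing the Jacobian out by hand (with a = 2x² + 3y² + 4z² and b = 4xyz) and evaluating it at the fiducial point reproduces the two rows printed above; a general row vector v is pulled back to the corresponding weighted sum of those rows:

$$
J(x,y,z) = \begin{pmatrix} \partial a/\partial x & \partial a/\partial y & \partial a/\partial z \\ \partial b/\partial x & \partial b/\partial y & \partial b/\partial z \end{pmatrix}
= \begin{pmatrix} 4x & 6y & 8z \\ 4yz & 4xz & 4xy \end{pmatrix},
\qquad
J(2,3,4) = \begin{pmatrix} 8 & 18 & 32 \\ 48 & 32 & 24 \end{pmatrix}
$$

$$
v^\top J = v_1\,(8,\ 18,\ 32) + v_2\,(48,\ 32,\ 24),
\qquad
(2,\ 0)\,J = (16,\ 36,\ 64)
$$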

That much makes sense to me: we're separately feeding [1,0] and [0,1] to func_vjp to get, respectively, the first and second rows of the Jacobian evaluated at our fiducial point (x,y,z)=(2,3,4).
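A minimal sketch of that basis-vector idea, assuming a recent JAX version (one where `jax.config.update` is how float64 is enabled): pull back each row of the 2×2 identity, stack the results into the full 2×3 Jacobian, and compare with `jax.jacrev`, which as far as I know builds its result from the same kind of basis-vector VJPs, vectorized with `vmap`. The primal is passed as float64 scalar arrays instead of Python floats only to keep dtypes explicit.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import vjp, jacrev


def vector_func(args):
    x, y, z = args
    a = 2*x**2 + 3*y**2 + 4*z**2
    b = 4*x*y*z
    return jnp.array([a, b])


primal = tuple(jnp.array([2.0, 3.0, 4.0]))  # (x, y, z) as float64 scalars
val, func_vjp = vjp(vector_func, primal)

# Pull back each standard basis row vector e_i; each result is row i of J.
rows = []
for e in jnp.eye(2):
    (grad_xyz,) = func_vjp(e)         # one cotangent per primal argument
    rows.append(jnp.stack(grad_xyz))  # (d/dx, d/dy, d/dz) as a length-3 array
J_from_vjp = jnp.stack(rows)          # shape (2, 3)

# With a tuple input, jacrev returns one length-2 column per input,
# so stack them along the last axis to get the same (2, 3) layout.
J_jacrev = jnp.stack(jacrev(vector_func)(primal), axis=-1)

print(J_from_vjp)
# [[ 8. 18. 32.]
#  [48. 32. 24.]]
```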

But now what if we fed func_vjp a non-identity row vector like [2,0]? Is this asking how the fiducial (x,y,z) would need to be perturbed to double the first component of the output? If so, is there a way to see this by evaluating vector_func at the perturbed parameter values?

I tried the following, but I'm not sure it's right:

# suppose I want to know what perturbations in (x,y,z) cause a doubling of the first output and no change in second output component 
print(func_vjp(jnp.array([2.0,0.0])))
# [16,36,64] 

### Attempts to use the output of func_vjp to verify that val becomes [99*2, 96]
### none of these work

print(vector_func([16,36,64]))
# [20784, 147456]

print(vector_func([x*16,y*36,z*64]))
# [299184., 3538944.]

What am I doing wrong in using the output of func_vjp to modify the fiducial parameters (x,y,z) and feeding those back into vector_func to verify that the parameter perturbations double the first component of the original output and leave the second component unchanged?


Solution

  • I think in your question you are confusing primal and tangent vector spaces. The function vector_func is a non-linear function that maps a vector in an input primal vector space (represented by (x, y, z)) to a vector in an output primal vector space (represented by val in your code).

    The function func_vjp is a linear function that maps a vector in an output tangent (strictly, cotangent) vector space (represented by array([2, 0]) in your question) to a vector in an input tangent (cotangent) vector space ([16,36,64] in your question).

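    By construction, func_vjp applies the transpose of the Jacobian J of vector_func, evaluated at the specified primal values: an output tangent v is mapped to the input tangent Jᵀv, which is the gradient of the scalar function v · vector_func(x, y, z). So feeding it [2, 0] returns the gradient of 2a, i.e. twice the first row of J; it does not tell you which input perturbation would double the first output. Perturbations propagate in the opposite direction: if you infinitesimally perturb the input primal by δ, the output primal moves by Jδ (this is what jax.jvp computes). In particular, perturbing the input along Jᵀv moves the output by JJᵀv, which in general is not v.

    If you want to check the values, you could do something like this:

    input_primal = (x, y, z)
    output_primal, func_vjp = vjp(vector_func, input_primal)
    
    epsilon = 1E-8  # note: small value so we're near the linear regime
    output_tangent = epsilon * jnp.array([0.0, 1.0])
    input_tangent, = func_vjp(output_tangent)
    
    # Compute the perturbed output given the perturbed input
    perturbed_input = [p + t for p, t in zip(input_primal, input_tangent)]
    perturbed_output_1 = vector_func(perturbed_input)
    print(perturbed_output_1)
    # [99.00001728 96.00003904]
    
    # Perturb the output directly
    perturbed_output_2 = output_primal + output_tangent
    print(perturbed_output_2)
    # [99.         96.00000001]
    

    Note that the results don't match, and making epsilon smaller will not fix that: the shift in perturbed_output_1 is epsilon · JJᵀ[0, 1] = 1e-8 · [1728, 3904], a first-order (linear) effect rather than a symptom of the function's nonlinearity. What the VJP does guarantee, to first order, is the dot-product identity v · (f(x + δ) − f(x)) ≈ (Jᵀv) · δ for any small input perturbation δ. To find the perturbation that produces a given small output change Δy, you have to solve Jδ = Δy (for example with the least-norm solution δ = Jᵀ(JJᵀ)⁻¹Δy), and for a large change such as doubling the first output the linear approximation does not apply at all. Hopefully this helps clarify what these primal and tangent values mean in the context of the VJP computation: gradient computations are all about these kinds of infinitesimal limits.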
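The sketch below checks four statements numerically at the question's point (2, 3, 4), assuming a recent JAX with float64 enabled; `perturb` is a hypothetical helper defined here, not part of JAX. It verifies that func_vjp(v) equals the gradient of v · f, that moving the input along Jᵀu shifts the output by JJᵀu (pushed forward with `jax.jvp`), that the first-order identity u · (f(x + δ) − f(x)) ≈ (Jᵀu) · δ holds for an arbitrary small δ, and that a small target output change can be reached with a least-norm solve of Jδ = Δy.

```python
import jax
jax.config.update("jax_enable_x64", True)  # match the question's float64 setup
import jax.numpy as jnp
from jax import vjp, jvp, grad


def vector_func(args):
    x, y, z = args
    a = 2*x**2 + 3*y**2 + 4*z**2
    b = 4*x*y*z
    return jnp.array([a, b])


def perturb(primal, delta):
    # Hypothetical helper: add each component of delta to the matching input.
    return tuple(p + d for p, d in zip(primal, delta))


primal = tuple(jnp.array([2.0, 3.0, 4.0]))  # (x, y, z) as float64 scalars
val, func_vjp = vjp(vector_func, primal)

# 1) func_vjp(v) is J^T v, i.e. the gradient of the scalar v . f(x, y, z).
v = jnp.array([2.0, 0.0])
(jtv,) = func_vjp(v)
jtv = jnp.stack(jtv)
grad_vf = jnp.stack(grad(lambda args: jnp.dot(v, vector_func(args)))(primal))
print(jtv)  # [16. 36. 64.]

# 2) Moving the input along J^T u moves the output by J J^T u (a JVP), not by u.
#    For u = [0, 1] this is the per-unit-epsilon shift seen in the check above.
u = jnp.array([0.0, 1.0])
(jtu,) = func_vjp(u)
_, jjtu = jvp(vector_func, (primal,), (jtu,))
print(jjtu)  # [1728. 3904.]

# 3) The identity the VJP does satisfy, to first order in a small delta:
#    u . (f(x + delta) - f(x)) ~= (J^T u) . delta
delta = 1e-7 * jnp.array([1.0, -2.0, 0.5])  # arbitrary small perturbation
lhs = jnp.dot(u, vector_func(perturb(primal, delta)) - val)
rhs = jnp.dot(jnp.stack(jtu), delta)

# 4) To reach a small target output change dy, solve J delta = dy.
#    Least-norm solution: delta = J^T w with (J J^T) w = dy; J^T w is a VJP.
J = jnp.stack([jnp.stack(func_vjp(e)[0]) for e in jnp.eye(2)])  # rows of J
dy = 1e-6 * jnp.array([1.0, 0.0])  # nudge the first output only
w = jnp.linalg.solve(J @ J.T, dy)
(delta_ln,) = func_vjp(w)
achieved = vector_func(perturb(primal, delta_ln)) - val
```

The least-norm solve is only one choice: with 3 inputs and 2 outputs there is a whole line of δ satisfying Jδ = Δy, and the linearization only holds for small Δy. Doubling the first output (a change of 99) would need a nonlinear root-finding step instead.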