Tags: python, vectorization, jax, automatic-differentiation

JAX `vjp` fails for vmapped function with `custom_vjp`


Below is an example in which a function with a custom vector-Jacobian product (`custom_vjp`) is vmapped. Even for a function this simple, calling `vjp` on the vmapped version fails:

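from functools import partial
from typing import Callable

import jax.numpy as jnp
from jax import Array, custom_vjp, vjp, vmap


@partial(custom_vjp, nondiff_argnums=(0,))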
def test_func(f: Callable[..., float],
              R: Array
              ) -> float:

    return f(jnp.dot(R, R))


def test_func_fwd(f, primal):

    primal_out = test_func(f, primal)
    residual = 2. * primal * primal_out
    return primal_out, residual


def test_func_bwd(f, residual, cotangent):

    cotangent_out = residual * cotangent
    return (cotangent_out, )


test_func.defvjp(test_func_fwd, test_func_bwd)

test_func = vmap(test_func, in_axes=(None, 0))


if __name__ == "__main__":

    def f(x):
        return x

    # vjp
    primal, f_vjp = vjp(partial(test_func, f),
                        jnp.ones((10, 3))
                        )

    cotangent = jnp.ones(10)
    cotangent_out = f_vjp(cotangent)

    print(cotangent_out[0].shape)

The error message says:

ValueError: Shape of cotangent input to vjp pullback function (10,) must be the same as the shape of corresponding primal input (10, 3).

Here, I think the error message is misleading: the cotangent passed to the pullback should have the same shape as the primal output, which is (10,) in this case. Still, it's not clear to me why this error occurs.


Solution

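  • The problem is that test_func_fwd calls test_func, but by the time it runs you have rebound the global name test_func to its vmapped version. Python looks up global names when the call happens, so the forward rule ends up calling the vmapped function rather than the original custom_vjp function. If you bind the vmapped function to a different name and leave the original test_func unchanged in the global namespace, your code will work as expected: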

    ...
    
    test_func_mapped = vmap(test_func, in_axes=(None, 0))
    
    ... 
    
    primal, f_vjp = vjp(partial(test_func_mapped, f),
                        jnp.ones((10, 3))
                        )