Tags: tensorflow, autograd, automatic-differentiation, jax, autodiff

Automatic Differentiation with respect to rank-based computations


I'm new to automatic differentiation programming, so this may be a naive question. Below is a simplified version of what I'm trying to solve.

I have two input arrays - a vector A of size N and a matrix B of shape (N, M) - as well as a parameter vector theta of size M. I define a new array C(theta) = B * theta to get a new vector of size N. I then obtain the indices of the elements that fall in the upper and lower quartiles of C, and use them to create two new arrays, A_low(theta) = A[lower quartile indices of C] and A_high(theta) = A[upper quartile indices of C]. Clearly these two depend on theta, but is it possible to differentiate A_low and A_high w.r.t. theta?
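For concreteness, here is a rough sketch of the forward computation I have in mind (plain NumPy, with argsort used for the quartile selection; the exact quartile convention is not important):

    import numpy as np

    N, M = 100, 20
    rng = np.random.default_rng(0)
    A = rng.random(N)
    B = rng.random((N, M))
    theta = rng.random(M)

    C = B @ theta            # C(theta), shape (N,)
    order = np.argsort(C)    # indices that sort C in ascending order
    k = N // 4               # size of one quartile
    A_low = A[order[:k]]     # A at the lower-quartile indices of C
    A_high = A[order[-k:]]   # A at the upper-quartile indices of C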

My attempts so far seem to suggest not - I have tried the Python libraries autograd, JAX, and TensorFlow, but they all return a gradient of zero. (The approaches I have tried so far involve using argsort or extracting the relevant sub-arrays using tf.top_k.)

What I'm seeking help with is either a proof that the derivative is not defined (or cannot be analytically computed) or, if it does exist, a suggestion on how to estimate it. My eventual goal is to minimize some function f(A_low, A_high) w.r.t. theta.


Solution

  • This is the JAX computation that I wrote based on your description:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax import lax
    
    N = 10
    M = 20
    
    rng = np.random.default_rng(0)
    A = jnp.array(rng.random((N,)))
    B = jnp.array(rng.random((N, M)))
    theta = jnp.array(rng.random(M))
    
    def f(A, B, theta, k=3):
      C = B @ theta                   # C(theta) = B * theta, shape (N,)
      _, i_upper = lax.top_k(C, k)    # indices of the k largest entries of C
      _, i_lower = lax.top_k(-C, k)   # indices of the k smallest entries of C
      return A[i_lower], A[i_upper]
    
    x, y = f(A, B, theta)
    dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)
    

    The derivatives are all zero, and I believe this is correct: an infinitesimal change in theta produces no change in the value of the outputs.
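
    As a quick sanity check (reusing the variables from the snippet above):

    print(dx_dtheta.shape, dy_dtheta.shape)   # (3, 20) (3, 20)
    print(jnp.all(dx_dtheta == 0))            # True
    print(jnp.all(dy_dtheta == 0))            # True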

    But, you might ask, how can this be? After all, theta enters into the computation, and if you put in a different value for theta, you get different outputs. How could the gradient be zero?

    What you must keep in mind, though, is that differentiation doesn't measure whether an input affects an output. It measures the change in output given an infinitesimal change in input.

    Let's use a slightly simpler function as an example:

    import jax
    import jax.numpy as jnp
    
    A = jnp.array([1.0, 2.0, 3.0])
    theta = jnp.array([5.0, 1.0, 3.0])
    
    def f(A, theta):
      return A[jnp.argmax(theta)]   # pick the entry of A at the position of theta's largest value
    
    x = f(A, theta)                               # 1.0, since theta[0] = 5.0 is the largest entry
    dx_dtheta = jax.grad(f, argnums=1)(A, theta)  # [0., 0., 0.]
    

    Here the result of differentiating f with respect to theta is all zero, for the same reason as above. Why? If you make an infinitesimal change to theta, it will in general not affect the sort order of theta. Thus the entries you choose from A do not change under an infinitesimal change in theta, and so the derivative with respect to theta is zero.
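
    You can see this concretely by perturbing theta a little and checking that the output does not move (continuing with A and theta from the snippet above):

    eps = 1e-3
    print(f(A, theta))                  # 1.0 -- A[0], since theta[0] = 5.0 is the largest entry
    print(f(A, theta.at[0].add(eps)))   # 1.0 -- argmax unchanged, output unchanged
    print(f(A, theta.at[1].add(eps)))   # 1.0 -- argmax unchanged, output unchanged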

    Now, you might argue that there are circumstances where this is not the case: for example, if two values in theta are very close together, then certainly perturbing one even infinitesimally could change their respective rank. This is true, but the gradient resulting from this procedure is undefined (the change in output is not smooth with respect to the change in input). The good news is this discontinuity is one-sided: if you perturb in the other direction, there is no change in rank and the gradient is well-defined. In order to avoid undefined gradients, most autodiff systems will implicitly use this safer definition of a derivative for rank-based computations.
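
    Here is a small sketch of that one-sided discontinuity (reusing f and A from the previous snippet, with an exactly tied theta for illustration):

    theta_tie = jnp.array([2.0, 2.0, 1.0])    # theta_tie[0] and theta_tie[1] are tied
    eps = 1e-4

    print(f(A, theta_tie))                    # 1.0 -- A[0]; argmax breaks ties toward the first index
    print(f(A, theta_tie.at[1].add(eps)))     # 2.0 -- A[1]; the rank flipped and the output jumped
    print(f(A, theta_tie.at[1].add(-eps)))    # 1.0 -- A[0]; perturbing the other way changes nothing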

    The result is that the value of the output does not change when you infinitesimally perturb the input, which is another way of saying the gradient is zero. And this is not a failure of autodiff – it is the correct gradient given the definition of differentiation that autodiff is built on. Moreover, were you to try changing to a different definition of the derivative at these discontinuities, the best you could hope for would be undefined outputs, so the definition that results in zeros is arguably more useful and correct.