
Jax - autograd of a sigmoid always returns nan


I am trying to differentiate a function that approximates the fraction of a Gaussian contained within two limits (a truncated Gaussian), given a shifted mean. jax.grad does not let me differentiate combining boolean filters (the commented line below), so I've had to improvise with a sigmoid instead.

However, now the gradient is always nan when the truncation boundary is high, and I don't understand why.

In the example below I'm calculating the gradient of a Gaussian with mean 0 and std 1, which I then shift around with x.

If I decrease the boundary, the function behaves as expected, but that is not a real solution. When the boundary is high, belows becomes 1 all the time; but if that is the case and x has no impact on below, then its contribution to the gradient should be 0, not nan. Even if I return belows[0][0] instead of jnp.mean(filt, axis=0), I still get nan.

Any ideas? Thanks in advance (there is an issue open on GitHub as well).

import os

from tqdm import tqdm

os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' # Use 8 CPU devices
import numpy as np
from jax.config import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
from jax import vmap

from functools import reduce

def sigmoid(x, scale=100):
    return 1 / (1 + jnp.exp(-x*scale))

def above_lower(x, l, scale=100):
    return sigmoid(x - l, scale)

def below_upper(x, u, scale=100):
    return 1 - sigmoid(x - u, scale)

def combine_soft_filters(a):
    return jnp.prod(jnp.stack(a), axis=0)


def fraction_not_truncated(mu, v, limits, stdnorm_samples):
    # Transform standard-normal samples into samples from N(mu, v)
    L = jnp.linalg.cholesky(v)
    y = vmap(lambda x: jnp.dot(L, x))(stdnorm_samples) + mu
    # Hard boolean filter (not differentiable):
    # filt = reduce(jnp.logical_and, [(y[..., i] > l) & (y[..., i] < u) for i, (l, u) in enumerate(limits)])
    # Soft filter: sigmoid approximations of the lower/upper boundary indicators
    aboves = [above_lower(y[..., i], l) for i, (l, u) in enumerate(limits)]
    belows = [below_upper(y[..., i], u) for i, (l, u) in enumerate(limits)]
    filt = combine_soft_filters(aboves + belows)
    return jnp.mean(filt, axis=0)

limits = np.array([
        [0.,1000],
])

stdnorm_samples = np.random.multivariate_normal([0], np.eye(1), size=1000)

def func(x):
    return fraction_not_truncated(jnp.zeros(1)+x, jnp.eye(1), limits, stdnorm_samples)

_x = np.linspace(-2, 2, 500)
gradfunc = jax.grad(func)
vals = [func(x) for x in tqdm(_x)]
grads = [gradfunc(x) for x in tqdm(_x)]
print(vals)
print(grads)
import matplotlib.pyplot as plt
plt.plot(_x, np.asarray(vals))
plt.ylabel('f(x)')
plt.twinx()
plt.plot(_x, np.asarray(grads), c='r')
plt.ylabel("f(x)'")
plt.title('Fraction not truncated')
plt.axhline(0, color='k', alpha=0.2)
plt.xlabel('shift')
plt.tight_layout()
plt.show()

[Plot titled "Fraction not truncated": f(x) against the shift, with f(x)' on a twin axis]

[DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), ...]   (every value is 1.0)
[DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), ...]   (every gradient is nan)

Solution

  • The issue is that your sigmoid function is implemented in such a way that the automatically determined gradient is not stable for large negative values of x:

    import jax.numpy as jnp
    import jax
    
    def sigmoid(x, scale=100):
        return 1 / (1 + jnp.exp(-x*scale))
    
    print(jax.grad(sigmoid)(-1000.0))
    # nan
    

    You can see why this is happening using the jax.make_jaxpr function to introspect the operations produced by the automatically determined gradient (comments are my annotations):

    >>> jax.make_jaxpr(jax.grad(sigmoid))(-1000.0)
    { lambda  ; a.                    # a = -1000
      let b = neg a                   # b = 1000
          c = mul b 100.0             # c = 100,000
          d = exp c                   # d = inf
          e = add d 1.0
          _ = div 1.0 e
          f = integer_pow[ y=-2 ] e   # f = 0
          g = mul 1.0 f               # g = 0
          h = mul g 1.0               # h = 0
          i = neg h                   # i = 0
          j = mul i d                 # j = 0 * inf = NaN
          k = mul j 100.0             # k = NaN
          l = neg k                   # l = NaN
      in (l,) }                       # return NaN
    

    This is one of those cases where 64-bit floating point arithmetic fails you: it doesn't have the range to work with numbers like exp(100000).
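
    To see the overflow and the resulting 0 * inf product directly, a two-line check is enough:

    import jax.numpy as jnp

    print(jnp.exp(100000.0))        # inf: exp overflows the floating point range
    print(0.0 * jnp.exp(100000.0))  # nan: the 0 * inf product that appears in the jaxpr above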

    So what can you do? One heavy-weight option would be to use a custom derivative rule to tell autodiff how to handle the sigmoid function in a more stable way (a rough sketch of that approach is included at the end of this answer). In this case, though, an easier option is to re-express the sigmoid function in terms of something that is better behaved under autodiff transformations. One option is this:

    def sigmoid(x, scale=100):
        return 0.5 * (jnp.tanh(x * scale / 2) + 1)
    

    Using this version in your script fixes the issue.
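
    A quick way to confirm the reformulation in isolation: the derivative of tanh is 1 - tanh^2, which never overflows, so the gradient comes out as 0 rather than nan for large negative inputs:

    import jax
    import jax.numpy as jnp

    def sigmoid(x, scale=100):
        return 0.5 * (jnp.tanh(x * scale / 2) + 1)

    print(jax.grad(sigmoid)(-1000.0))
    # 0.0

    For completeness, the heavier-weight custom-derivative-rule option mentioned above might look roughly like the following sketch (hard-coding scale=100 for brevity). It keeps the original forward expression but tells autodiff to use the stable closed-form derivative scale * s * (1 - s) instead of differentiating through the exp:

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def sigmoid(x):
        return 1 / (1 + jnp.exp(-x * 100))

    @sigmoid.defjvp
    def sigmoid_jvp(primals, tangents):
        x, = primals
        x_dot, = tangents
        s = sigmoid(x)
        # 100 * s * (1 - s) stays finite even when exp(100 * |x|) overflows to inf
        return s, 100 * s * (1 - s) * x_dot

    print(jax.grad(sigmoid)(-1000.0))
    # 0.0 instead of nan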