Tags: python, machine-learning, deep-learning, autograd, jax

Purpose of stop gradient in `jax.nn.softmax`?


jax.nn.softmax is defined as:

def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None) -> Array:
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
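
For reference, a minimal set of imports that makes this quoted snippet runnable on its own (a sketch; `Array` is assumed here to be `jax.Array`, which is close to what the JAX source uses internally):

from typing import Optional, Tuple, Union

import jax.numpy as jnp
from jax import Array, lax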

I'm particularly interested in the `lax.stop_gradient(x_max)` part and would appreciate an explanation of why it's needed. From a practical standpoint, stop_gradient doesn't appear to change the gradient calculation:

import jax
import jax.numpy as jnp

def softmax_unstable(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

def softmax_stable(x):
    x = x - jnp.max(x)
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

def softmax_stop_gradient(x):
    x = x - jax.lax.stop_gradient(jnp.max(x))
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

# example input
x = jax.random.normal(jax.random.PRNGKey(123), (100,))

# make sure all forward passes are equal
a = softmax_unstable(x)
b = softmax_stable(x)
c = softmax_stop_gradient(x)
d = jax.nn.softmax(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)

# make sure all gradient calculations are the same
a = jax.grad(lambda x: -jnp.log(softmax_unstable(x))[2])(x)
b = jax.grad(lambda x: -jnp.log(softmax_stable(x))[2])(x)
c = jax.grad(lambda x: -jnp.log(softmax_stop_gradient(x))[2])(x)
d = jax.grad(lambda x: -jnp.log(jax.nn.softmax(x))[2])(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)

# make sure all gradient calculations are the same, this time applying each softmax twice
a = jax.grad(lambda x: -jnp.log(softmax_unstable(softmax_unstable(x)))[2])(x)
b = jax.grad(lambda x: -jnp.log(softmax_stable(softmax_stable(x)))[2])(x)
c = jax.grad(lambda x: -jnp.log(softmax_stop_gradient(softmax_stop_gradient(x)))[2])(x)
d = jax.grad(lambda x: -jnp.log(jax.nn.softmax(jax.nn.softmax(x)))[2])(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)

^ All implementations agree, even the one that applies the x - x_max trick but WITHOUT stop_gradient.


Solution

  • First off, the reason for subtracting x_max at all is that it prevents overflow for large inputs. For example:

    x = jnp.array([1, 2, 1000])
    
    print(softmax_unstable(x))
    # [ 0.  0. nan]
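    # (exp(1000) overflows to inf in float32, so the last entry becomes inf / inf = nan)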
    print(softmax_stable(x))
    # [0. 0. 1.]
    print(softmax_stop_gradient(x))
    # [0. 0. 1.]
    

  • As for why we use stop_gradient here: we can show analytically that the max(x) term cancels out in the gradient computation (a short derivation is sketched at the end of this answer), so we know a priori that its gradient cannot affect the gradient of the overall function. Marking it with stop_gradient communicates this to JAX's autodiff machinery, leading to a more efficient gradient computation. You can see this efficiency in action by printing the jaxpr for each version of the gradient function:

    x = jnp.float32(1)
    print(jax.make_jaxpr(jax.grad(softmax_stable))(x))
    
    { lambda ; a:f32[]. let
        b:f32[] = reduce_max[axes=()] a
        c:f32[] = reshape[dimensions=None new_sizes=()] b
        d:bool[] = eq a c
        e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
        f:f32[] = reduce_sum[axes=()] e
        g:f32[] = sub a b
        h:f32[] = exp g
        i:f32[] = exp g
        j:f32[] = reduce_sum[axes=()] i
        _:f32[] = div h j
        k:f32[] = integer_pow[y=-2] j
        l:f32[] = mul 1.0 k
        m:f32[] = mul l h
        n:f32[] = neg m
        o:f32[] = div 1.0 j
        p:f32[] = mul n i
        q:f32[] = mul o h
        r:f32[] = add_any p q
        s:f32[] = neg r
        t:f32[] = div s f
        u:f32[] = mul t e
        v:f32[] = add_any r u
      in (v,) }
    
    print(jax.make_jaxpr(jax.grad(softmax_stop_gradient))(x))
    
    { lambda ; a:f32[]. let
        b:f32[] = reduce_max[axes=()] a
        c:f32[] = reshape[dimensions=None new_sizes=()] b
        d:bool[] = eq a c
        e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
        _:f32[] = reduce_sum[axes=()] e
        f:f32[] = stop_gradient b
        g:f32[] = sub a f
        h:f32[] = exp g
        i:f32[] = exp g
        j:f32[] = reduce_sum[axes=()] i
        _:f32[] = div h j
        k:f32[] = integer_pow[y=-2] j
        l:f32[] = mul 1.0 k
        m:f32[] = mul l h
        n:f32[] = neg m
        o:f32[] = div 1.0 j
        p:f32[] = mul n i
        q:f32[] = mul o h
        r:f32[] = add_any p q
      in (r,) }
    

    The second version requires fewer computations to achieve the same result, because we've told the autodiff machinery it does not need to differentiate through max(x). The extra work in the first jaxpr is the handful of trailing operations (the neg, div, mul, and add_any after r), which backpropagate the gradient through reduce_max; per the derivation below, that contribution is identically zero.
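
For completeness, here is the cancellation referred to above, written out (a sketch; the shorthand m = max(x), s = softmax(x - m), and the Kronecker delta δ are introduced here, and m is treated as locally differentiable). Softmax is invariant to shifting its input by any scalar c:

$$
\mathrm{softmax}(x - c)_i
  = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
  = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}}
  = \mathrm{softmax}(x)_i,
$$

so the output, and hence its gradient with respect to x, does not depend on the shift. Making the cancellation explicit for the log-probabilities used in the question:

$$
\frac{\partial \log s_i}{\partial x_k}
  = \Big(\delta_{ik} - \frac{\partial m}{\partial x_k}\Big)
    - \sum_j s_j \Big(\delta_{jk} - \frac{\partial m}{\partial x_k}\Big)
  = \delta_{ik} - s_k - \frac{\partial m}{\partial x_k}\Big(1 - \sum_j s_j\Big)
  = \delta_{ik} - s_k,
$$

since \(\sum_j s_j = 1\). The \(\partial m / \partial x_k\) terms always cancel, so wrapping m in stop_gradient simply tells JAX to drop that path up front rather than computing a contribution that is guaranteed to be zero.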