jax.nn.softmax is defined as:
def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None) -> Array:
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
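To make the signature concrete, here is a minimal illustrative call (the values are arbitrary; the default axis=-1 normalizes over the last axis):

import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0],
                    [1.0, 1.0, 1.0]])
probs = jax.nn.softmax(logits, axis=-1)  # normalizes each row independently
print(probs.sum(axis=-1))                # [1. 1.] -- each row sums to 1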
I'm particularly interested in the lax.stop_gradient(x_max) part. I would love an explanation for why it's needed. From a practical standpoint, it seems that stop_gradient doesn't change the gradient calculation:
import jax
import jax.numpy as jnp

def softmax_unstable(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

def softmax_stable(x):
    x = x - jnp.max(x)
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

def softmax_stop_gradient(x):
    x = x - jax.lax.stop_gradient(jnp.max(x))
    return jnp.exp(x) / jnp.sum(jnp.exp(x))
# example input
x = jax.random.normal(jax.random.PRNGKey(123), (100,))
# make sure all forward passes are equal
a = softmax_unstable(x)
b = softmax_stable(x)
c = softmax_stop_gradient(x)
d = jax.nn.softmax(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)
# make sure all gradient calculations are the same
a = jax.grad(lambda x: -jnp.log(softmax_unstable(x))[2])(x)
b = jax.grad(lambda x: -jnp.log(softmax_stable(x))[2])(x)
c = jax.grad(lambda x: -jnp.log(softmax_stop_gradient(x))[2])(x)
d = jax.grad(lambda x: -jnp.log(jax.nn.softmax(x))[2])(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)
# make sure all gradient calculations are the same, this time we use softmax functions twice
a = jax.grad(lambda x: -jnp.log(softmax_unstable(softmax_unstable(x)))[2])(x)
b = jax.grad(lambda x: -jnp.log(softmax_stable(softmax_stable(x)))[2])(x)
c = jax.grad(lambda x: -jnp.log(softmax_stop_gradient(softmax_stop_gradient(x)))[2])(x)
d = jax.grad(lambda x: -jnp.log(jax.nn.softmax(jax.nn.softmax(x)))[2])(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)
^ all implementations are equal, even the one where we apply the x - x_max trick but WITHOUT stop_gradient.
First off, the reason for subtracting x_max at all is that it prevents overflow for large inputs. For example:
x = jnp.array([1, 2, 1000])
print(softmax_unstable(x))
# [ 0. 0. nan]
print(softmax_stable(x))
# [0. 0. 1.]
print(softmax_stop_gradient(x))
# [0. 0. 1.]
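The failure mode is exp itself overflowing in float32: the largest representable float32 is about 3.4e38, so exp hits inf for inputs somewhere around 88-89. In softmax_unstable, exp(1000) is inf, and inf/inf gives the nan above. Subtracting the max shifts the largest argument of exp to 0, so nothing overflows. A rough illustration (thresholds are specific to float32):

print(jnp.exp(jnp.float32(88.0)))    # ~1.65e38, still representable in float32
print(jnp.exp(jnp.float32(90.0)))    # inf -- past the float32 limit (~3.4e38)
print(jnp.exp(jnp.float32(1000.0) - jnp.float32(1000.0)))  # 1.0 after the max-shift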
As for why we use stop_gradient here: we can show analytically that the max(x) term cancels out in the gradient computation, so we know a priori that its gradient cannot affect the gradient of the overall function (a short sketch of that cancellation is included at the end of this answer). Marking it as stop_gradient communicates this to JAX's autodiff machinery, leading to a more efficient gradient computation. You can see this efficiency in action by printing the jaxpr for each version of the gradient function:
x = jnp.float32(1)
print(jax.make_jaxpr(jax.grad(softmax_stable))(x))
{ lambda ; a:f32[]. let
b:f32[] = reduce_max[axes=()] a
c:f32[] = reshape[dimensions=None new_sizes=()] b
d:bool[] = eq a c
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
f:f32[] = reduce_sum[axes=()] e
g:f32[] = sub a b
h:f32[] = exp g
i:f32[] = exp g
j:f32[] = reduce_sum[axes=()] i
_:f32[] = div h j
k:f32[] = integer_pow[y=-2] j
l:f32[] = mul 1.0 k
m:f32[] = mul l h
n:f32[] = neg m
o:f32[] = div 1.0 j
p:f32[] = mul n i
q:f32[] = mul o h
r:f32[] = add_any p q
s:f32[] = neg r
t:f32[] = div s f
u:f32[] = mul t e
v:f32[] = add_any r u
in (v,) }
print(jax.make_jaxpr(jax.grad(softmax_stop_gradient))(x))
{ lambda ; a:f32[]. let
b:f32[] = reduce_max[axes=()] a
c:f32[] = reshape[dimensions=None new_sizes=()] b
d:bool[] = eq a c
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
_:f32[] = reduce_sum[axes=()] e
f:f32[] = stop_gradient b
g:f32[] = sub a f
h:f32[] = exp g
i:f32[] = exp g
j:f32[] = reduce_sum[axes=()] i
_:f32[] = div h j
k:f32[] = integer_pow[y=-2] j
l:f32[] = mul 1.0 k
m:f32[] = mul l h
n:f32[] = neg m
o:f32[] = div 1.0 j
p:f32[] = mul n i
q:f32[] = mul o h
r:f32[] = add_any p q
in (r,) }
The second version requires fewer computations to achieve the same result, because we've essentially told the autodiff machinery that it does not have to worry about differentiating max(x). Concretely, the trailing s/t/u/v operations in the first jaxpr are the gradient flowing back through reduce_max (distributed via the eq mask); the stop_gradient version ends at r and skips them.
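For completeness, here is a quick sketch of the cancellation argument mentioned above (my own illustration; softmax_shifted is just a helper for the demo). Softmax is shift-invariant: softmax(x - c) == softmax(x) for any scalar c, because the exp(-c) factor appears in both the numerator and the denominator and divides out. The output therefore does not depend on c at all, so the gradient flowing through c = max(x) is identically zero, which is exactly the work stop_gradient tells autodiff to skip:

def softmax_shifted(x, c):
    # exp(-c) cancels between numerator and denominator, so the value is independent of c
    return jnp.exp(x - c) / jnp.sum(jnp.exp(x - c))

x = jax.random.normal(jax.random.PRNGKey(0), (5,))
grad_wrt_shift = jax.grad(lambda c: -jnp.log(softmax_shifted(x, c))[2])(0.0)
print(grad_wrt_shift)  # ~0.0 (exactly zero analytically; only float rounding remains)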