
JAX: average per-example gradients different from aggregated gradients


I want to compute per-example gradients to perform per-example clipping before the final gradient descent step.

I wanted to make sure the per-example gradients are correct, so I wrote the minimal example below, which compares standard gradient computation with per-example gradient computation in JAX.

The problem is that the average of the per-example gradients differs from the gradient computed the standard way.

Can anyone see where I went wrong?

import jax
import jax.numpy as jnp
from jax import random


def loss(params, x, t):
    w, b = params
    y = jnp.dot(w, x.T) + b
    return ((t.T - y)**2).sum()


def main():

    n_samples = 3
    dims_in = 7
    dims_out = 5

    key = random.PRNGKey(0)

    # Random data
    x = random.normal(key, (n_samples, dims_in), dtype=jnp.float32)
    t = random.normal(key, (n_samples, dims_out), dtype=jnp.float32)
    
    # Random weights
    w = random.normal(key, (dims_out, dims_in), dtype=jnp.float32)
    b = random.normal(key, (dims_out, 1), dtype=jnp.float32)
    params = (w, b)

    # Standard gradient
    reduced_grads = jax.grad(loss)
    dw0, db0 = reduced_grads(params, x, t)
    print(f"{dw0.shape = }") 
    print(f"{db0.shape = }") 

    # Per-example gradients
    perex_grads = jax.vmap(jax.grad(loss), in_axes=((None, None), 0, 0))
    dw1, db1 = perex_grads(params, x, t)
    print(f"{dw1.shape = }") 
    print(f"{db1.shape = }")

    # Gradients are different!
    print(jnp.allclose(dw0, jnp.mean(dw1, axis=0)))  # should be True
    print(jnp.allclose(db0, jnp.mean(db1, axis=0)))  # should be True
    

if __name__ == "__main__":
    main()

Solution

  • The issue is that vmap effectively passes 1D inputs to your function, and your loss function behaves differently on 1D and 2D inputs. You can see this by comparing the following:

    print(loss(params, x[0], t[0]))
    # 87.5574
    
    print(loss(params, x[:1], t[:1]))
    # 18.089672
    

    If you inspect the shapes of the intermediate results within the loss function, it becomes clear why these differ. With a 1D x, jnp.dot(w, x.T) has shape (5,), and adding b of shape (5, 1) broadcasts y to shape (5, 5). You therefore end up summing over a shape (5, 5) array of differences instead of a shape (5, 1) array.
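A minimal sketch of that shape inspection, reusing the question's dimensions (with freshly split random keys instead of one shared key). intermediate_shapes is a hypothetical helper that repeats the loss arithmetic but returns the shapes of y and of the difference instead of summing them.

```python
import jax.numpy as jnp
from jax import random

n_samples, dims_in, dims_out = 3, 7, 5
kx, kt, kw, kb = random.split(random.PRNGKey(0), 4)
x = random.normal(kx, (n_samples, dims_in))
t = random.normal(kt, (n_samples, dims_out))
w = random.normal(kw, (dims_out, dims_in))
b = random.normal(kb, (dims_out, 1))  # column vector, as in the question


def intermediate_shapes(x, t):
    # Same arithmetic as the question's loss, but returns shapes instead of the sum.
    y = jnp.dot(w, x.T) + b  # for 1D x: (5,) + (5, 1) broadcasts to (5, 5)
    diff = t.T - y
    return y.shape, diff.shape


print(intermediate_shapes(x[0], t[0]))    # 1D inputs, as vmap passes them
# ((5, 5), (5, 5))
print(intermediate_shapes(x[:1], t[:1]))  # 2D inputs with a batch of one
# ((5, 1), (5, 1))
```

Note that x.T is a no-op on a 1D array, so transposing does not turn the single example into a column.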

    To use per-example gradients, you'll need to change the implementation of your loss function so that it returns the correct result whether the input is 1D or 2D. The easiest way to do this is probably to use jnp.atleast_2d to ensure that the inputs are two-dimensional:

    def loss(params, x, t):
        x = jnp.atleast_2d(x)
        t = jnp.atleast_2d(t)
        w, b = params
        y = jnp.dot(w, x.T) + b
        return ((t.T - y)**2).sum()
    

    With this change, the sum (not the mean) of the per-example gradients matches the full computation, because the loss sums the squared errors over examples rather than averaging them:

    print(jnp.allclose(dw0, jnp.sum(dw1, axis=0)))
    # True
    print(jnp.allclose(db0, jnp.sum(db1, axis=0)))
    # True
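The question originally expected the mean of the per-example gradients to match; that holds when the batch loss is itself a mean over examples. The sketch below assumes a loss written for a single example, with b stored as a vector of shape (dims_out,) rather than the question's (dims_out, 1) so no (5, 5) broadcast can occur, and builds the batch loss by vmapping and averaging it. It then applies per-example clipping, the goal stated in the question. example_loss, batch_loss, clip_per_example and max_norm are hypothetical names, and clipping each example's global L2 norm before averaging is one common choice (as in DP-SGD), not the only one.

```python
import jax
import jax.numpy as jnp
from jax import random


def example_loss(params, x, t):
    # Loss for ONE example: x has shape (dims_in,), t has shape (dims_out,).
    # Assumes b has shape (dims_out,), so y stays shape (dims_out,).
    w, b = params
    y = jnp.dot(w, x) + b
    return ((t - y) ** 2).sum()


def batch_loss(params, x, t):
    # Batch loss defined as the MEAN of the per-example losses.
    losses = jax.vmap(example_loss, in_axes=(None, 0, 0))(params, x, t)
    return losses.mean()


# Per-example gradients: every leaf gets a leading batch axis.
per_example_grads = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))


def clip_per_example(grads, max_norm):
    # Hypothetical helper: rescale each example's gradient so that its global
    # L2 norm (over all parameter leaves) is at most max_norm.
    leaves = jax.tree_util.tree_leaves(grads)
    sq_norms = sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1) for g in leaves)
    scale = jnp.minimum(1.0, max_norm / (jnp.sqrt(sq_norms) + 1e-12))
    # Broadcast the per-example scale over each leaf's trailing axes.
    return jax.tree_util.tree_map(
        lambda g: g * scale.reshape((-1,) + (1,) * (g.ndim - 1)), grads
    )


n_samples, dims_in, dims_out = 3, 7, 5
kx, kt, kw, kb = random.split(random.PRNGKey(0), 4)
x = random.normal(kx, (n_samples, dims_in))
t = random.normal(kt, (n_samples, dims_out))
params = (random.normal(kw, (dims_out, dims_in)), random.normal(kb, (dims_out,)))

# With a mean-reduced batch loss, the MEAN of per-example gradients matches.
dw0, db0 = jax.grad(batch_loss)(params, x, t)
dw1, db1 = per_example_grads(params, x, t)
print(jnp.allclose(dw0, dw1.mean(axis=0), atol=1e-5),
      jnp.allclose(db0, db1.mean(axis=0), atol=1e-5))

# Clip each example's gradient, then average to get the update direction.
dw_c, db_c = clip_per_example((dw1, db1), max_norm=1.0)
update = (dw_c.mean(axis=0), db_c.mean(axis=0))
```

If max_norm is larger than every per-example gradient norm, clipping leaves the gradients unchanged and update reduces to the plain mean gradient, i.e. jax.grad(batch_loss).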