Tags: arrays, numpy, rounding, jax

Why do I get different values from jnp.round and np.round?


I'm writing tests for some JAX code using the np.testing.assert_array... family of functions and came across a difference in values that I didn't expect:

import jax.numpy as jnp
import numpy as np
from numpy.testing import assert_array_equal

a = jnp.array([-0.78073686, -0.7908204 ,  2.174842])
b = np.array(a, dtype='float32')
assert_array_equal(a, b)

print(a.round(2), a.dtype)
print(b.round(2), b.dtype)

Output:

[-0.78       -0.78999996  2.1699998 ] float32
[-0.78 -0.79  2.17] float32

Test:

assert_array_equal(a.round(2), b.round(2))

Output:

AssertionError: 
Arrays are not equal

Mismatched elements: 2 / 3 (66.7%)
Max absolute difference: 2.3841858e-07
Max relative difference: 1.0987031e-07
 x: array([-0.78, -0.79,  2.17], dtype=float32)
 y: array([-0.78, -0.79,  2.17], dtype=float32)

Footnote:

I get exactly the same results if I define b as follows, so it's not a problem with the conversion of the array from jax to numpy:

b = np.array([-0.78073686, -0.7908204 ,  2.174842], dtype='float32')

Solution

  • This is an example of a general property of floating point computations: two different ways of expressing the same computation will not always produce bitwise-equivalent outputs (see e.g. Is floating point math broken?).

    JAX and NumPy use identical implementations for x.round(2); essentially it is round_to_int(x * 100) / 100 (compare the JAX implementation and the NumPy implementation).
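As a rough illustration of that formula, here is a NumPy-only sketch. The helper name `round_decimals` is hypothetical and not part of either library, and the sketch assumes the intermediate arithmetic stays in the array's own floating-point dtype (float32 for the arrays in the question).

```python
import numpy as np

def round_decimals(x, decimals):
    """Hypothetical helper: round_to_int(x * 10**decimals) / 10**decimals."""
    # Keep the scale factor in the array's own dtype (assumed floating point)
    # so that no step is silently promoted to float64.
    factor = x.dtype.type(10 ** decimals)
    scaled = x * factor          # step 1: shift the wanted digits left of the point
    nearest = np.rint(scaled)    # step 2: round to the nearest integer (ties to even)
    return nearest / factor      # step 3: shift back

b = np.array([-0.78073686, -0.7908204, 2.174842], dtype=np.float32)
print(round_decimals(b, 2))
print(b.round(2))
```

Each of the three steps rounds its result to float32, which is why the order and form of the operations can matter.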

    The difference is that JAX jit-compiles jnp.round by default. When you disable compilation and perform these operations in sequence, the results are identical:

    import jax
    with jax.disable_jit():
      assert_array_equal(a.round(2), b.round(2))  # passes!
    

    With JIT enabled, however, the compiler optimizes the implementation by fusing some operations. This makes the computation faster, but in general you should not expect the result to be bitwise-equivalent to the unoptimized version.
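The answer does not say which rewrite the compiler applies. One candidate that happens to reproduce the numbers in the question is replacing the division by 100 with a multiplication by the float32 value nearest to 0.01. This is an assumption made for illustration, not a confirmed description of what JAX's compiler does. A NumPy-only sketch of the two variants:

```python
import numpy as np

x = np.array([-0.78073686, -0.7908204, 2.174842], dtype=np.float32)
hundred = np.float32(100)

nearest = np.rint(x * hundred)             # [-78., -79., 217.] as float32

divided = nearest / hundred                # unoptimized form: true division
reciprocal = nearest * np.float32(0.01)    # assumed rewrite: multiply by 1/100

print(divided)
print(reciprocal)
print("mismatched:", int(np.sum(divided != reciprocal)))
print("max abs diff:", np.max(np.abs(divided - reciprocal)))
```

Neither result is wrong. 0.01 has no exact float32 representation, so multiplying by it can round differently from dividing by 100; for these inputs the two variants differ by at most one float32 step.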

    To address this, whenever you are comparing floating point values, you should avoid exact equality checks in favor of checks that take this floating point roundoff error into account. For example:

    np.testing.assert_allclose(a.round(2), b.round(2), rtol=1E-6)  # passes!
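
The `rtol=1E-6` above is one reasonable choice. As a sketch of how a tolerance could be tied to the dtype instead of hard-coded, the snippet below uses `np.nextafter` to build an array that sits one float32 step away from another, then compares the two with a tolerance of a few machine epsilons. The factor of 4 is an arbitrary illustrative margin, not a recommendation from the answer.

```python
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

x = np.array([-0.78, -0.79, 2.17], dtype=np.float32)
# Move every element one representable float32 step toward zero.
y = np.nextafter(x, np.float32(0))

eps = np.finfo(np.float32).eps    # gap between 1.0 and the next float32
rtol = 4 * eps                    # illustrative margin of a few steps

assert_allclose(x, y, rtol=rtol)  # passes: differences are within tolerance

try:
    assert_array_equal(x, y)      # fails: the arrays are not bitwise equal
except AssertionError:
    print("exact comparison failed, as expected")
```

A purely relative tolerance stops being useful when the expected values are at or near zero; `assert_allclose` also accepts an `atol` argument for that case.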