
The kernel dies when calling jax.random.PRNGKey



The Jupyter kernel dies when I call jax.random.PRNGKey. I'm using Python 3.10.10 with jax 0.4.9 and jaxlib 0.4.9, running in JupyterLab on an M2 MacBook. What should I do to make it work?

I upgraded the packages, but it didn't help. I have no idea why it keeps crashing.
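Since upgrading didn't change anything, it may be worth confirming which versions the running kernel actually loads: a Jupyter kernel can use a different interpreter than the shell where `pip` was run. The following is a minimal sketch using only the standard library; the function name `environment_report` is hypothetical, chosen for illustration. Run it in a notebook cell to print the same details the question lists.

```python
import platform
import sys
from importlib import metadata


def environment_report():
    """Collect the Python/jax/jaxlib/platform details as seen by this interpreter."""

    def version_of(package):
        try:
            return metadata.version(package)
        except metadata.PackageNotFoundError:
            return None  # package is not installed in this interpreter

    return {
        "python": platform.python_version(),
        "executable": sys.executable,  # which interpreter the kernel is using
        "jax": version_of("jax"),
        "jaxlib": version_of("jaxlib"),
        "system": platform.system(),    # "Darwin" on macOS
        "machine": platform.machine(),  # "arm64" for native Python on Apple Silicon
    }


if __name__ == "__main__":
    for name, value in environment_report().items():
        print(f"{name}: {value}")
```

If `jaxlib` still reports 0.4.9 after an upgrade, the upgrade most likely went into a different environment than the kernel's.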


Solution

  • It looks like the jaxlib 0.4.9 release is broken on Mac ARM (see JAX Issue #15951). For now, I'd recommend installing jax/jaxlib version 0.4.8 to fix the issue.

    Update: jax & jaxlib 0.4.10 have been released and should fix this issue.
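To tell whether a given setup matches the broken combination described above (jaxlib 0.4.9 on an ARM Mac), a small check like the one below could be run in the notebook. This is a sketch under stated assumptions: the names `KNOWN_BAD_JAXLIB`, `is_affected`, and `check_current_environment` are hypothetical, and the only version it flags is the one named in this answer.

```python
import platform
from importlib import metadata

# Assumption: only the release named in the answer (JAX Issue #15951) is flagged.
KNOWN_BAD_JAXLIB = {"0.4.9"}


def is_affected(jaxlib_version, system, machine):
    """True if this combination matches the crash described in the answer."""
    return (
        jaxlib_version in KNOWN_BAD_JAXLIB
        and system == "Darwin"   # macOS
        and machine == "arm64"   # Apple Silicon (e.g. M1/M2), native Python
    )


def check_current_environment():
    """Return a one-line verdict for the interpreter running this code."""
    try:
        installed = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return "jaxlib is not installed in this interpreter."
    if is_affected(installed, platform.system(), platform.machine()):
        return (
            f"jaxlib {installed} on macOS arm64 matches the reported crash; "
            "install jax/jaxlib 0.4.8, or upgrade to 0.4.10."
        )
    return f"jaxlib {installed}: not the reported bad combination."


if __name__ == "__main__":
    print(check_current_environment())
```

To change versions from inside the notebook so that the kernel's own interpreter is the one modified, IPython's `%pip` magic can be used, e.g. `%pip install "jax==0.4.10" "jaxlib==0.4.10"`, followed by a kernel restart.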