The kernel dies when I call jax.random.PRNGKey. Python version is 3.10.10, jax version is 0.4.9, and jaxlib version is 0.4.9. It was run in JupyterLab on an M2 MacBook. What should I do to make it work?
I upgraded the packages, but it didn't help. I have no idea why it keeps crashing.
It looks like the jaxlib 0.4.9 release is broken on Mac ARM (see JAX Issue #15951). For now I'd recommend installing jax/jaxlib version 0.4.8 to fix the issue.
Update: jax and jaxlib 0.4.10 have been released and should fix this issue.
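To check whether an environment has the broken combination before running anything, a small script like the following may help. It is a sketch, not part of the original answer. It assumes the only known-bad combination is jaxlib 0.4.9 on macOS arm64 (as reported in JAX Issue #15951), and the helper names is_known_bad and check_installed are hypothetical. It reads the installed version through importlib.metadata instead of importing jax, so it never calls jax.random.PRNGKey.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Assumption: per JAX Issue #15951, jaxlib 0.4.9 is the release that crashes on Mac ARM.
KNOWN_BAD_JAXLIB_ON_MAC_ARM = {"0.4.9"}


def is_known_bad(jaxlib_version: str, system: str, machine: str) -> bool:
    """Return True if this jaxlib version / platform pair is the reported-broken one (hypothetical helper)."""
    return (
        system == "Darwin"  # macOS
        and machine == "arm64"  # Apple Silicon, e.g. M1/M2
        and jaxlib_version in KNOWN_BAD_JAXLIB_ON_MAC_ARM
    )


def check_installed() -> str:
    """Describe the installed jaxlib without importing jax (hypothetical helper)."""
    try:
        jaxlib_version = version("jaxlib")  # reads package metadata only
    except PackageNotFoundError:
        return "jaxlib is not installed"
    system, machine = platform.system(), platform.machine()
    if is_known_bad(jaxlib_version, system, machine):
        return (
            f"jaxlib {jaxlib_version} is known to crash on Mac ARM; "
            "install 0.4.8 or upgrade to 0.4.10 or later"
        )
    return f"jaxlib {jaxlib_version} on {system} {machine}: not the known-broken combination"


if __name__ == "__main__":
    print(check_installed())
```

If it reports the bad combination, pinning the versions as suggested above (for example pip install "jax==0.4.8" "jaxlib==0.4.8") or upgrading (pip install --upgrade "jax>=0.4.10" "jaxlib>=0.4.10") should avoid the crash. Restart the Jupyter kernel afterwards so the newly installed versions are actually loaded.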