My JAX code runs fine, but when I insert a breakpoint with jax.debug.breakpoint I get a jax.errors.UnexpectedTracerError.
If a tracer were really being leaked, I would expect this error to show up without the breakpoint as well.
Is this intended behavior, or is something weird happening? When using jax_checking_leaks, none of the reported tracers seem to actually be leaked.
There is currently a bug in jax.debug.breakpoint that can lead to spurious tracer leaks in some situations; see https://github.com/google/jax/issues/16732.
Unfortunately there is no easy workaround at the moment, but hopefully the issue will be addressed soon.
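For reference, the leak checker mentioned in the question can be turned on globally through the jax_check_tracer_leaks config option. Below is a minimal sketch of that. The jitted function f is a hypothetical stand-in for the asker's code, not taken from the post, and the jax.debug.breakpoint() call is left as a comment so the snippet runs non-interactively.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the asker's function (not from the original post).
@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    # jax.debug.breakpoint() would go here; it is omitted so this
    # sketch does not drop into an interactive debugger.
    return y + 1.0

# Turn on tracer-leak checking, run the function, then restore the default.
jax.config.update("jax_check_tracer_leaks", True)
try:
    out = f(jnp.arange(3.0))
finally:
    jax.config.update("jax_check_tracer_leaks", False)
```

With the bug described above, errors reported this way when a breakpoint is present may be spurious rather than a real leak in your own code.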