python tensorflow jax

Is there a module to convert a TensorFlow NN to JAX?


There is a library to convert JAX functions to TensorFlow functions. Is there a similar library to convert TensorFlow functions to JAX functions?


Solution

  • No. The JAX team does not support any library that converts TensorFlow code into JAX in the way that jax.experimental.jax2tf converts JAX code to TensorFlow, and I have not seen such a library developed by anyone else.