Tags: python, neural-network, jax

Training neural networks with JAX


I am a beginner in JAX and I am trying to learn how to train a neural network. I have read some blogs, but as far as I understand there isn't a library that lets you train a model easily with something like 'fit' in scikit-learn. I am interested in a classification task. Could you please recommend any blogs whose approach I could adapt to my problem?


Solution

  • JAX is an array manipulation library, not a deep learning library: in that respect, you should think of it as more similar to NumPy than to scikit-learn. If you want to build neural networks on top of JAX, there are several good projects available, such as Haiku and Flax.
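Because JAX itself is NumPy-like, "training" in plain JAX means writing the model, the loss, and the update loop yourself. The following is a minimal sketch under stated assumptions: a small multilayer perceptron classifier trained with full-batch gradient descent on hypothetical toy data (two Gaussian blobs). The names `init_params`, `predict`, `loss_fn`, `update`, and `fit`, along with the layer sizes, learning rate, and epoch count, are illustrative choices, not part of any library API. It uses only core JAX calls (`jax.grad`-style differentiation via `jax.value_and_grad`, `jax.jit`, `jax.nn`, `jax.random`, `jax.tree_util.tree_map`).

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value for this toy problem


def init_params(key, sizes):
    """Hypothetical helper: He-initialised (weights, bias) pairs for an MLP."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


def predict(params, x):
    """Forward pass: ReLU hidden layers, raw logits from the last layer."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def loss_fn(params, x, y):
    """Mean softmax cross-entropy; y holds integer class labels."""
    log_probs = jax.nn.log_softmax(predict(params, x))
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))


@jax.jit
def update(params, x, y):
    """One gradient-descent step over the whole (pytree of) parameters."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


def fit(params, x, y, epochs=200):
    """A hand-written stand-in for an sklearn-style 'fit' loop."""
    loss = None
    for _ in range(epochs):
        params, loss = update(params, x, y)
    return params, loss


def accuracy(params, x, y):
    """Fraction of samples whose highest logit matches the label."""
    return jnp.mean(jnp.argmax(predict(params, x), axis=1) == y)


# Hypothetical toy data: two 2-D Gaussian blobs, labels 0 and 1.
key = jax.random.PRNGKey(0)
k0, k1, k_params = jax.random.split(key, 3)
n = 200
x0 = jax.random.normal(k0, (n, 2)) + jnp.array([-2.0, -2.0])
x1 = jax.random.normal(k1, (n, 2)) + jnp.array([2.0, 2.0])
x = jnp.concatenate([x0, x1])
y = jnp.concatenate([jnp.zeros(n, dtype=jnp.int32), jnp.ones(n, dtype=jnp.int32)])

params = init_params(k_params, [2, 16, 2])
initial_loss = loss_fn(params, x, y)
trained_params, final_loss = fit(params, x, y)
acc = accuracy(trained_params, x, y)
print(f"loss {float(initial_loss):.3f} -> {float(final_loss):.3f}, accuracy {float(acc):.3f}")
```

Parameters are stored as a plain list of `(w, b)` tuples, which JAX treats as a pytree, so `jax.tree_util.tree_map` can apply the update to every array at once. Libraries such as Haiku and Flax mainly take over the layer definitions and parameter bookkeeping; in their examples the training loop is still typically written explicitly in a similar style.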