Tags: python-3.x, tpu, jax, spacy-transformers, sentence-transformers

Unable to import the Python package jax on a Google TPU machine


I am working in a Linux console on a TPU machine, where typing python3 starts the Python interpreter. When I run the following command there:

import jax

the interpreter crashes with the following message and drops me out of the Python prompt:

paramjeetsingh80@t1v-n-1c883486-w-0:~$ python3
Python 3.8.5 (default, Jan 27 2021, 15:41:15)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2021-07-08 17:41:39.660523: F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
Aborted (core dumped)
paramjeetsingh80@t1v-n-1c883486-w-0:~$

This error is breaking my code, so I would like to understand what causes it and how to fix it.
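Because the failure is a fatal abort (SIGABRT) rather than a Python exception, wrapping import jax in try/except cannot catch it; the whole process dies. A minimal sketch, assuming a POSIX system: run the import in a child interpreter and inspect its exit status, so the calling script survives and can report the problem. The helper name try_import is hypothetical and not part of jax.

```python
import signal
import subprocess
import sys
from typing import Tuple  # typing.Tuple keeps this compatible with Python 3.8


def try_import(module: str) -> Tuple[bool, str]:
    """Import `module` in a child interpreter so a fatal abort cannot kill this process."""
    result = subprocess.run(
        [sys.executable, "-c", f"import {module}"],
        capture_output=True,
        text=True,
    )
    if result.returncode == 0:
        return True, "ok"
    # On POSIX, a child killed by a signal reports the negated signal number.
    if result.returncode == -signal.SIGABRT:
        # Keep only the last 300 characters of stderr for a short report.
        return False, "aborted (SIGABRT): " + result.stderr.strip()[-300:]
    return False, result.stderr.strip()[-300:]


if __name__ == "__main__":
    ok, detail = try_import("jax")
    print("import jax:", "ok" if ok else "failed")
    if not ok:
        print(detail)
```

On a machine with the mismatch described above, the report would include the same "not available in this library" line shown in the log, while the parent script keeps running.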


Solution

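  • It may be that your system does not have the version of libtpu that matches your installed jax/jaxlib. The fatal error says a function jaxlib expects (TpuTransferManager_ReadDynamicShapes) is not available in the loaded library, which is consistent with such a version mismatch. Try installing the version listed here.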

    You should be able to install the matching version automatically with:

    $ pip install -U pip  # older pip may not support extra requirements
    $ pip install -U jax  # newer jax required for [tpu] extras declaration
    $ pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html  # quotes stop the shell from treating [tpu] as a glob
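After reinstalling, one way to confirm that jax, jaxlib, and the TPU runtime were upgraded together is to read their installed versions from package metadata, which works without importing jax (and therefore without triggering the abort). A minimal sketch, assuming the TPU runtime wheel is named libtpu-nightly (the name used around the time of this question) or libtpu (assumed as a later name); adjust PACKAGES if your environment uses different distribution names.

```python
from importlib import metadata  # available in the standard library since Python 3.8
from typing import Dict, Iterable, Optional

# Distribution names to check. "libtpu-nightly" and "libtpu" are assumptions
# about how the TPU runtime wheel is named in a given environment.
PACKAGES = ("jax", "jaxlib", "libtpu-nightly", "libtpu")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:15} {version or 'not installed'}")
```

If jaxlib was upgraded but the runtime wheel still shows an old version (or is missing), the mismatch described above is the likely culprit.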