I'm having trouble getting PyTorch 2.2 running with TPUs on Google Colab. I'm getting an error about a JAX bug, but I'm confused about this because I'm not doing anything with JAX.
My setup process is very simple:
!pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
And then
import torch
import torch_xla.core.xla_model as xm
which prints the following warnings:
/usr/local/lib/python3.10/dist-packages/jax/__init__.py:27: UserWarning: cloud_tpu_init failed: KeyError('')
This a JAX bug; please report an issue at https://github.com/google/jax/issues
_warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
Then running
t1 = torch.tensor(100, device=xm.xla_device())
t2 = torch.tensor(200, device=xm.xla_device())
print(t1 + t2)
gives the error
/usr/local/lib/python3.10/dist-packages/torch_xla/runtime.py in xla_device(n, devkind)
122 if n is None:
--> 123 return torch.device(torch_xla._XLAC._xla_get_default_device())
125 devices = xm.get_xla_supported_devices(devkind=devkind)
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: No ba16c7433 device found.
Colab currently provides only an older generation of TPUs, which is not compatible with recent JAX or PyTorch releases. That may change in the future, but I don't know of any official timeline for when it might happen. In the meantime, you can access recent-generation TPUs via Kaggle or Google Cloud.
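Until you move to a runtime with a supported TPU, one option is to keep the notebook usable by falling back to CPU when the XLA device cannot be created. The sketch below assumes you are fine running on CPU in that case. The helper name `pick_device` is hypothetical. It catches the `RuntimeError` that `xm.xla_device()` raised above ("TPU initialization failed") and the `ImportError` raised when `torch_xla` is not installed. It returns the string `"cpu"` because `torch.tensor(..., device=...)` accepts device strings.

```python
import warnings


def pick_device():
    """Hypothetical helper: return an XLA device if one can be created, else "cpu".

    Sketch only. It falls back to CPU when torch_xla is missing (ImportError)
    or when device creation fails, e.g. "TPU initialization failed" (RuntimeError).
    """
    try:
        # Imported lazily so the fallback still works where torch_xla is absent.
        import torch_xla.core.xla_model as xm
        return xm.xla_device()
    except ImportError:
        warnings.warn("torch_xla is not installed; falling back to CPU")
    except RuntimeError as exc:
        warnings.warn(f"XLA device unavailable ({exc}); falling back to CPU")
    return "cpu"


if __name__ == "__main__":
    device = pick_device()
    print(f"Using device: {device}")
```

With this in place, `t1 = torch.tensor(100, device=pick_device())` runs on CPU instead of raising. Keep in mind that silently falling back can hide the fact that you are not on a TPU, which is why the sketch emits a warning rather than failing quietly.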