Tags: python, google-colaboratory, jax, stable-diffusion

Stable diffusion: AttributeError: module 'jax.random' has no attribute 'KeyArray'


When I run the Stable Diffusion notebook on Colab (https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
without any modification, it fails on the line

from diffusers import StableDiffusionPipeline

The error message is:

AttributeError: module 'jax.random' has no attribute 'KeyArray'

How can I fix this? Any clues?

Expected behavior: the import should succeed and the notebook should run without errors.


Solution

  • jax.random.KeyArray was deprecated in JAX v0.4.16 and removed in JAX v0.4.24. This suggests that the Hugging Face Stable Diffusion code used by the notebook only works with JAX v0.4.23 or earlier.

    You can install JAX v0.4.23 with GPU support like this:

    pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    

    or, if you prefer targeting a local CUDA installation, like this:

    pip install "jax[cuda12_local]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    

    For more information on GPU installation, see JAX Installation: NVIDIA GPU.

    In the Colab tutorial, change the second code cell to:

    !pip install "jax[cuda12_local]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    !pip install diffusers==0.11.1
    !pip install transformers scipy ftfy accelerate
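A minimal sketch, not part of the original answer, of a cell you could run after the install cell to confirm which JAX and diffusers versions are actually active before retrying the import. It uses only the standard library (`importlib.metadata`), so it does not import JAX itself; the helper names (`parse_version`, `jax_has_keyarray`, `installed_version`) are hypothetical. The only version cutoff it encodes is the one stated above (KeyArray removed in v0.4.24). If JAX had already been imported in the session before reinstalling, Colab typically needs a runtime restart before the new version is picked up.

```python
from importlib import metadata
from typing import Optional, Tuple

# Cutoff taken from the answer above: jax.random.KeyArray was removed in JAX v0.4.24.
KEYARRAY_REMOVED_IN = (0, 4, 24)


def parse_version(v: str) -> Tuple[int, int, int]:
    """Turn '0.4.23', '0.4.23+cuda12.cudnn89' or '0.4.24rc1' into (major, minor, patch)."""
    parts = []
    # Drop any local suffix after '+', then keep the leading digits of each component.
    for piece in v.split("+")[0].split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (3 - len(parts))  # pad short versions like '0.5' to three fields
    return tuple(parts)


def jax_has_keyarray(jax_version: str) -> bool:
    """True if this JAX version is older than 0.4.24, the release that removed KeyArray."""
    return parse_version(jax_version) < KEYARRAY_REMOVED_IN


def installed_version(package: str) -> Optional[str]:
    """Installed version string of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_version = installed_version("jax")
    if jax_version is None:
        print("jax is not installed")
    else:
        status = "OK" if jax_has_keyarray(jax_version) else "too new (no jax.random.KeyArray)"
        print(f"jax {jax_version}: {status}")
    print(f"diffusers {installed_version('diffusers')}")
```

If the check reports a version of 0.4.24 or newer, the pinned install did not take effect in the running session, and the `from diffusers import StableDiffusionPipeline` line will likely fail with the same AttributeError.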