
AttributeError: module 'flax' has no attribute 'nn'


I'm trying to run RegNeRF, which requires flax. After installing the latest version of flax (0.6.0), I got an error stating that flax has no attribute optim. This answer suggested downgrading flax to 0.5.1. After doing that, I now get the error AttributeError: module 'flax' has no attribute 'nn'

I could not find any solutions on the web for this error. Any help is appreciated.

I'm using Ubuntu 20.04.


Solution

  • The flax.optim module has been moved to optax as of flax version 0.6.0; see Upgrading my Codebase to Optax for information on how to migrate your code. If you're using external code that imports flax.optim and can't update these references, you'll have to install flax version 0.5.3 or older.

    Regarding flax.nn: this module was replaced by flax.linen in flax version 0.4.0. See Upgrading my Codebase to Linen for information on this migration. If you're using external code that imports flax.nn and can't update these references, you'll have to install flax version 0.3.6 or older.
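Because the code in the question fails on both flax.optim (with 0.6.0) and flax.nn (with 0.5.1), it references both legacy modules. Under the version limits above, the newest release that still has both is 0.3.6. If you want to check which legacy modules a codebase uses before choosing a pin, here is a minimal sketch that scans Python source with the standard-library ast module. The helper names (legacy_flax_modules, newest_compatible_flax) are hypothetical and not part of flax. The version limits are copied from this answer, and a static scan will miss imports that are built dynamically.

```python
from __future__ import annotations

import ast

# Newest flax releases that still ship each legacy module, per the answer:
# flax.nn was replaced by flax.linen in 0.4.0; flax.optim moved to optax in 0.6.0.
LAST_VERSION_WITH = {"nn": (0, 3, 6), "optim": (0, 5, 3)}


def legacy_flax_modules(source: str) -> set[str]:
    """Hypothetical helper: which legacy flax submodules ('nn', 'optim') the source uses."""
    found = set()
    for node in ast.walk(ast.parse(source)):
        # `import flax.nn` / `import flax.optim as optim`
        if isinstance(node, ast.Import):
            for alias in node.names:
                parts = alias.name.split(".")
                if parts[0] == "flax" and len(parts) > 1 and parts[1] in LAST_VERSION_WITH:
                    found.add(parts[1])
        # `from flax import nn, optim` / `from flax.optim import Adam`
        elif isinstance(node, ast.ImportFrom) and node.module:
            parts = node.module.split(".")
            if parts[0] != "flax":
                continue
            if len(parts) > 1 and parts[1] in LAST_VERSION_WITH:
                found.add(parts[1])
            elif len(parts) == 1:
                found.update(a.name for a in node.names if a.name in LAST_VERSION_WITH)
        # Attribute access such as `flax.nn.Dense` after a plain `import flax`
        elif (isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name)
              and node.value.id == "flax" and node.attr in LAST_VERSION_WITH):
            found.add(node.attr)
    return found


def newest_compatible_flax(modules: set[str]) -> str | None:
    """Hypothetical helper: newest flax pin providing every module (None = no legacy use)."""
    if not modules:
        return None
    version = min(LAST_VERSION_WITH[m] for m in modules)
    return "flax==" + ".".join(map(str, version))


src = "from flax import nn, optim\nimport jax\n"
print(newest_compatible_flax(legacy_flax_modules(src)))  # flax==0.3.6
```

The helper takes the minimum of the per-module limits because the pin has to satisfy every legacy module at once. Pinning 0.5.1 fixes flax.optim but still fails on flax.nn, which matches the error in the question.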