Tags: python, numpy, pickle, jax, flax

Pickle changes type in jax


I have a Flax struct dataclass that contains a JAX array.

When I pickle this object and load it again, the field is no longer a JAX array: it has been converted to a NumPy array. Here is code that reproduces it:

import pickle

import jax.numpy as jnp
from flax import struct


@struct.dataclass
class A:
    data: jnp.ndarray


a = A(data=jnp.zeros((2, 2)))
print(a, type(a.data))  # a.data is a JAX array here

# Round-trip the object through a pickle file.
with open('file.pickle', 'wb') as handle:
    pickle.dump(a, handle)

with open('file.pickle', 'rb') as handle:
    loaded_a = pickle.load(handle)

print(loaded_a, type(loaded_a.data))  # on affected versions: numpy.ndarray

I don't want this behavior; I'd like the field to keep its original type. Is that possible?


Solution

  • Update: this bug has been fixed in https://github.com/google/jax/pull/10659. Starting with the next release of JAX (v0.3.14), pickle and deepcopy should no longer convert JAX arrays to NumPy arrays.
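To see whether the JAX version you have installed includes the fix, a quick round-trip check can help. The sketch below is an illustration, not part of the fix: `is_jax_array` and `check_roundtrip` are hypothetical helper names, and the extra `np.ndarray` exclusion is there because older JAX releases could also report plain NumPy arrays as instances of `jnp.ndarray`. On a version with the fix, both entries should come out `True`; on an affected version, expect `False`.

```python
import copy
import pickle

import jax.numpy as jnp
import numpy as np


def is_jax_array(x) -> bool:
    # Older JAX releases also treated plain NumPy arrays as instances of
    # jnp.ndarray, so NumPy arrays are excluded explicitly.
    return isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray)


def check_roundtrip(x) -> dict:
    """Report whether pickle and deepcopy keep `x` a JAX array."""
    return {
        "pickle": is_jax_array(pickle.loads(pickle.dumps(x))),
        "deepcopy": is_jax_array(copy.deepcopy(x)),
    }


print(check_roundtrip(jnp.zeros((2, 2))))
```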


    This is a known behavior in JAX; see https://github.com/google/jax/issues/2632

    The library developers recognize this as unfortunate behavior, but a fix has not yet been prioritized. If you're interested, you might weigh in on that issue.